From 4a18c6b30ed65b155934138af915525bcc22beb7 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Fri, 16 Jun 2023 17:16:11 +0800 Subject: [PATCH 01/61] add hetero feature selection(#4661) Signed-off-by: Yu Wu --- python/fate/components/components/__init__.py | 3 +- .../components/hetero_feature_selection.py | 183 ++++++ python/fate/components/params/__init__.py | 4 +- .../fate/components/params/_filter_param.py | 103 ++++ python/fate/ml/feature_selection/__init__.py | 16 + .../hetero_feature_selection.py | 559 ++++++++++++++++++ 6 files changed, 866 insertions(+), 2 deletions(-) create mode 100644 python/fate/components/components/hetero_feature_selection.py create mode 100644 python/fate/components/params/_filter_param.py create mode 100644 python/fate/ml/feature_selection/__init__.py create mode 100644 python/fate/ml/feature_selection/hetero_feature_selection.py diff --git a/python/fate/components/components/__init__.py b/python/fate/components/components/__init__.py index 0dbb1ab2f5..5b264f4aa3 100644 --- a/python/fate/components/components/__init__.py +++ b/python/fate/components/components/__init__.py @@ -14,9 +14,10 @@ # limitations under the License. from .evaluation import evaluation from .feature_scale import feature_scale +from .hetero_feature_selection import hetero_feature_selection from .hetero_lr import hetero_lr from .intersection import intersection from .reader import reader from .statistics import statistics -BUILDIN_COMPONENTS = [hetero_lr, reader, feature_scale, intersection, evaluation, statistics] +BUILDIN_COMPONENTS = [hetero_lr, reader, feature_scale, intersection, evaluation, statistics, hetero_feature_selection] diff --git a/python/fate/components/components/hetero_feature_selection.py b/python/fate/components/components/hetero_feature_selection.py new file mode 100644 index 0000000000..789df30c31 --- /dev/null +++ b/python/fate/components/components/hetero_feature_selection.py @@ -0,0 +1,183 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import List + +from fate.components import ( + GUEST, + HOST, + DatasetArtifact, + Input, + ModelArtifact, + Output, + Role, + cpn, + params +) + + +@cpn.component(roles=[GUEST, HOST]) +def hetero_feature_selection(ctx, role): + ... 
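+
+# NOTE: the bare `...` body above is only the component registration stub;
+# the runnable logic is attached below via the @hetero_feature_selection.train()
+# and @hetero_feature_selection.predict() stage decorators.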
+ + +@hetero_feature_selection.train() +@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST]) +@cpn.artifact("input_statistic_model", type=Input[ModelArtifact], roles=[GUEST, HOST], optional=True) +@cpn.artifact("input_binning_model", type=Input[ModelArtifact], roles=[GUEST, HOST], optional=True) +@cpn.parameter("method", type=List[params.string_choice(["manual", "binning", "statistic"])], + default=["manual"], optional=False, + desc="selection method, options: {manual, binning, statistic}") +@cpn.parameter("select_col", type=List[str], default=None, + desc="list of column names to be selected, if None, all columns will be considered") +@cpn.parameter("iv_param", type=params.iv_filter_param(), + default=params.IVFilterParam(metrics="iv", take_high=True, + threshold=1, filter_type="threshold", host_thresholds=1, + host_take_high=True, + select_federated=True), + desc="binning filter param") +@cpn.parameter("statistic_param", type=params.statistic_filter_param(), + default=params.StatisticFilterParam(metrics="mean", + threshold=1, filter_type="threshold", take_high=True), + desc="statistic filter param") +@cpn.parameter("manual_param", type=params.manual_filter_param(), + default=params.ManualFilterParam(filter_out_col=[], keep_col=[]), + desc="note that manual filter will always be processed as the last filter") +@cpn.parameter("keep_one", type=bool, default=True, desc="whether to keep at least one feature among `select_col`") +@cpn.parameter("use_anonymous", type=bool, default=False, + desc="bool, whether interpret `select_col` & `filter_out_col` & `keep_col` as anonymous column names") +@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) +@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST]) +def feature_selection_train( + ctx, + role: Role, + train_data, + input_statistic_model, + input_binning_model, + method, + select_col, + iv_param, + statistic_param, + manual_param, + keep_one, + use_anonymous, + train_output_data, + output_model, +): + train(ctx, role, train_data, train_output_data, input_binning_model, input_statistic_model, + output_model, method, select_col, iv_param, statistic_param, manual_param, + keep_one, use_anonymous) + + +@hetero_feature_selection.predict() +@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST]) +@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST]) +@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) +def feature_selection_predict( + ctx, + role: Role, + test_data, + input_model, + test_output_data, +): + predict(ctx, input_model, test_data, test_output_data, role) + + +def train(ctx, role, train_data, train_output_data, input_binning_model, input_statistic_model, + output_model, method, select_col, iv_param, statistic_param, manual_param, + keep_one, use_anonymous): + from fate.ml.feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest + + with ctx.sub_ctx("train") as sub_ctx: + isometric_model_dict = {} + if input_binning_model: + with input_binning_model as model_reader: + model = model_reader.read_model() + model_type = json.loads(model["model_meta"]).get("model_type") + if model_type != "binning": + raise ValueError(f"model type: {model_type} is not binning, but {model_type}") + isometric_model_dict["binning"] = model + if input_statistic_model: + with input_statistic_model as model_reader: + model = model_reader.read_model() + # temp code block + 
model_type = json.loads(model["model_meta"]).get("model_type") + if model_type != "statistic": + raise ValueError(f"model type: {model_type} is not statistic, but {model_type}") + # temp code block end + isometric_model_dict["statistic"] = model + + # logger.info(f"input model: {isometric_model_dict}") + + train_data = sub_ctx.reader(train_data).read_dataframe().data + columns = train_data.schema.columns.to_list() + if use_anonymous: + anonymous_columns = train_data.schema.anonymous_columns.to_list() + if select_col is not None: + select_col = [columns[anonymous_columns.index(col)] for col in select_col] + if manual_param.filter_out_col is not None: + filter_out_col = [columns[anonymous_columns.index(col)] for col in manual_param.filter_out_col] + manual_param.filter_out_col = filter_out_col + if manual_param.keep_col is not None: + keep_col = [columns[anonymous_columns.index(col)] for col in manual_param.keep_col] + manual_param.keep_col = keep_col + + if role.is_guest: + selection = HeteroSelectionModuleGuest(method, select_col, isometric_model_dict, + iv_param, statistic_param, manual_param, + keep_one) + elif role.is_host: + selection = HeteroSelectionModuleHost(method, select_col, isometric_model_dict, + iv_param, statistic_param, manual_param, + keep_one) + selection.fit(sub_ctx, train_data) + model = selection.to_model() + with output_model as model_writer: + model_writer.write_model("feature_selection", model, metadata={"method": method}) + + with ctx.sub_ctx("predict") as sub_ctx: + output_data = train_data + if method is not None: + output_data = selection.transform(sub_ctx, train_data) + sub_ctx.writer(train_output_data).write_dataframe(output_data) + + +def predict(ctx, input_model, test_data, test_output_data, role): + from fate.ml.feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest + + with ctx.sub_ctx("predict") as sub_ctx: + with input_model as model_reader: + model = model_reader.read_model() + if role.is_guest: + selection = HeteroSelectionModuleGuest.from_model(model) + elif role.is_host: + selection = HeteroSelectionModuleHost.from_model(model) + + model_meta = model["meta_data"] + method = model_meta["method"] + selection.method = method + test_data = sub_ctx.reader(test_data).read_dataframe().data + + output_data = test_data + if method is not None: + output_data = selection.transform(sub_ctx, test_data) + """ + # temp code start + test_data = sub_ctx.reader(test_data).read_dataframe().data + output_data = selection.transform(sub_ctx, test_data) + # temp code end + """ + sub_ctx.writer(test_output_data).write_dataframe(output_data) diff --git a/python/fate/components/params/__init__.py b/python/fate/components/params/__init__.py index d18accfb61..f159a29b48 100644 --- a/python/fate/components/params/__init__.py +++ b/python/fate/components/params/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The FATE Authors. All Rights Reserved. +# Copyright 2023 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,6 +16,8 @@ from ._cipher import CipherParamType, PaillierCipherParam from ._fields import Parameter, confloat, conint, jsonschema, parse, string_choice +from ._filter_param import StatisticFilterParam, IVFilterParam, ManualFilterParam, \ + statistic_filter_param, iv_filter_param, manual_filter_param from ._learning_rate import learning_rate_param from ._metrics import metrics_param, statistic_metrics_param from ._optimizer import optimizer_param diff --git a/python/fate/components/params/_filter_param.py b/python/fate/components/params/_filter_param.py new file mode 100644 index 0000000000..6ea875fd74 --- /dev/null +++ b/python/fate/components/params/_filter_param.py @@ -0,0 +1,103 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, List + +import pydantic + +from ._fields import string_choice, Parameter +from ._metrics import statistic_metrics_param + + +class StandardFilterParam(pydantic.BaseModel, Parameter): + metrics: List[str] + + filter_type: List[string_choice({'threshold', 'top_k', 'top_percentile'})] = ['threshold'] + threshold: List[Union[int, float]] = [1.0] + take_high: List[bool] = [True] + + @pydantic.validator('metrics', 'filter_type', 'threshold', 'take_high', pre=True, allow_reuse=True) + def to_list(cls, v): + return v if isinstance(v, list) else [v] + + @pydantic.root_validator(pre=False) + def check_filter_param_length(cls, values): + max_length = max([len(x) for k, x in values.items()]) + for k, v in values.items(): + if len(v) == 1: + v *= max_length + assert len(v) == max_length, f"Length of {k}: {v} does not match " \ + f"max length {max_length} of (metrics, filter_type, threshold, take_high)." + return values + + +class FederatedStandardFilterParam(StandardFilterParam, Parameter): + host_filter_type: List[string_choice({'threshold', 'top_k', 'top_percentile'})] = ['threshold'] + host_threshold: List[Union[int, float]] = [1.0] + host_take_high: List[bool] = [True] + + select_federated: bool = True + + @pydantic.validator('host_filter_type', 'host_threshold', 'host_take_high', pre=True, allow_reuse=True) + def to_list(cls, v): + return v if isinstance(v, list) else [v] + + @pydantic.root_validator(pre=False) + def check_filter_param_length(cls, values): + select_values = {k: v for k, v in values.items() if k != 'select_federated'} + max_length = max([len(x) for k, x in select_values.items()]) + for k, v in select_values.items(): + if len(v) == 1: + v *= max_length + assert len(v) == max_length, f"Length of {k}: {v} does not match " \ + f"max length {max_length} of (metrics, filter_type, threshold, take_high)." 
+        return values
+
+
+class IVFilterParam(FederatedStandardFilterParam, Parameter):
+    metrics: List[string_choice({'iv'})] = ['iv']
+
+
+class StatisticFilterParam(StandardFilterParam, Parameter):
+    metrics: List[statistic_metrics_param(describe=False)] = ["mean"]
+
+
+class ManualFilterParam(pydantic.BaseModel, Parameter):
+    keep_col: List[str] = []
+    # named `filter_out_col` to match the component parameter
+    # `ManualFilterParam(filter_out_col=[], keep_col=[])` and downstream
+    # `manual_param.filter_out_col` accesses
+    filter_out_col: List[str] = []
+
+    @pydantic.root_validator(pre=False)
+    def no_intersection(cls, values):
+        filter_out_col = values.get('filter_out_col', [])
+        keep_col = values.get('keep_col', [])
+        intersection = set(filter_out_col).intersection(set(keep_col))
+        if intersection:
+            raise ValueError(f"`keep_col` and `filter_out_col` share common elements: {intersection}")
+        return values
+
+
+def iv_filter_param():
+    namespace = {}
+    return type("IVFilterParam", (IVFilterParam,), namespace)
+
+
+def statistic_filter_param():
+    namespace = {}
+    return type("StatisticFilterParam", (StatisticFilterParam,), namespace)
+
+
+def manual_filter_param():
+    namespace = {}
+    return type("ManualFilterParam", (ManualFilterParam,), namespace)
diff --git a/python/fate/ml/feature_selection/__init__.py b/python/fate/ml/feature_selection/__init__.py
new file mode 100644
index 0000000000..3eb1f74a6b
--- /dev/null
+++ b/python/fate/ml/feature_selection/__init__.py
@@ -0,0 +1,16 @@
+#
+# Copyright 2023 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .hetero_feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest
diff --git a/python/fate/ml/feature_selection/hetero_feature_selection.py b/python/fate/ml/feature_selection/hetero_feature_selection.py
new file mode 100644
index 0000000000..e9a2f7feb4
--- /dev/null
+++ b/python/fate/ml/feature_selection/hetero_feature_selection.py
@@ -0,0 +1,559 @@
+#
+# Copyright 2023 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
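+
+# Design note: the modules below run the configured filters as a cascade --
+# each filter receives the previous filter's boolean column mask via
+# `set_prev_selected_mask` and may only narrow it. `ManualSelection` applies
+# explicit keep/filter-out lists; `StandardSelection` applies threshold /
+# top_k / top_percentile rules to metrics from an upstream binning or
+# statistic model, and for federated IV selection the guest ships each host
+# its mask in `sync_select_federated`. For illustration with assumed values:
+# StatisticFilterParam(metrics=["mean", "max"], filter_type="threshold",
+# threshold=[0.1, 10]) broadcasts into per-metric filter configurations
+# {"mean": {"filter_type": ["threshold"], "threshold": [0.1], "take_high": [True]},
+#  "max":  {"filter_type": ["threshold"], "threshold": [10],  "take_high": [True]}}.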
+
+import copy
+import json
+import logging
+import random
+
+import numpy as np
+import pandas as pd
+
+from fate.interface import Context
+from ..abc.module import Module, HeteroModule
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_METRIC = {"iv": ["iv"], "statistic": ["mean"]}
+
+
+class HeteroSelectionModuleGuest(HeteroModule):
+    def __init__(self, method=None, select_col=None, isometric_model_dict=None,
+                 iv_param=None, statistic_param=None, manual_param=None,
+                 keep_one=True):
+        self.method = method
+        self.select_col = select_col
+        self.isometric_model_dict = isometric_model_dict
+        self.iv_param = iv_param
+        self.statistic_param = statistic_param
+        self.manual_param = manual_param
+        self.keep_one = keep_one
+        # keep selection history
+        self._inner_method = []
+        self._selection_obj = []
+
+    def fit(self, ctx: Context, train_data, validate_data=None) -> None:
+        logger.info(f"isometric_model_dict: {self.isometric_model_dict}")
+        if self.select_col is None:
+            self.select_col = train_data.schema.columns.to_list()
+
+        select_data = train_data[self.select_col]
+        header = select_data.schema.columns.to_list()
+        for i, filter_type in enumerate(self.method):
+            if filter_type == "manual":
+                selection_obj = ManualSelection(method=filter_type,
+                                                header=header,
+                                                param=self.manual_param,
+                                                keep_one=self.keep_one)
+            elif filter_type == "iv":
+                model = self.isometric_model_dict.get("binning", None)
+                if model is None:
+                    raise ValueError(f"Cannot find binning model in input, please check")
+                selection_obj = StandardSelection(method=filter_type,
+                                                  header=header,
+                                                  param=self.iv_param,
+                                                  model=model,
+                                                  keep_one=self.keep_one)
+            elif filter_type == "statistic":
+                model = self.isometric_model_dict.get("statistic", None)
+                if model is None:
+                    raise ValueError(f"Cannot find statistic model in input, please check")
+                selection_obj = StandardSelection(method=filter_type,
+                                                  header=header,
+                                                  param=self.statistic_param,
+                                                  model=model,
+                                                  keep_one=self.keep_one)
+            else:
+                raise ValueError(f"{filter_type} selection method not supported, please check")
+            self._selection_obj.append(selection_obj)
+            self._inner_method.append(filter_type)
+
+        prev_selection_obj = None
+        for method, selection_obj in zip(self._inner_method, self._selection_obj):
+            if prev_selection_obj:
+                selection_obj.set_prev_selected_mask(copy.deepcopy(prev_selection_obj._selected_mask))
+                if isinstance(selection_obj, StandardSelection) and isinstance(prev_selection_obj, StandardSelection):
+                    selection_obj.set_host_prev_selected_mask(copy.deepcopy(prev_selection_obj._host_selected_mask))
+            selection_obj.fit(ctx, select_data)
+            # `_inner_method` stores filter names ("manual"/"iv"/"statistic"),
+            # so the federated sync triggers on "iv", matching the host side
+            if method == "iv":
+                if self.iv_param.select_federated:
+                    HeteroSelectionModuleGuest.sync_select_federated(ctx, selection_obj)
+            prev_selection_obj = selection_obj
+
+    @staticmethod
+    def sync_select_federated(ctx: Context, selection_obj):
+        logger.info(f"Sync federated selection.")
+        for i, host in enumerate(ctx.hosts):
+            federated_mask = selection_obj._host_selected_mask[host]
+            ctx.hosts[i].put(f"selected_mask_{selection_obj.method}", federated_mask)
+
+    def transform(self, ctx: Context, test_data):
+        transformed_data = self._selection_obj[-1].transform(ctx, test_data)
+        return transformed_data
+
+    def to_model(self):
+        # all selection obj need to be recorded for display of cascade order
+        selection_obj_list = []
+        for selection_obj in self._selection_obj:
+            selection_obj_list.append(selection_obj.to_model())
+        return {"selection_obj_list": json.dumps(selection_obj_list),
+                "method": self.method,
+                "select_col": self.select_col,
+                "inner_method": self._inner_method}
+
+    def restore(self, model):
+        selection_obj_list = []
+        selection_obj_model_list = json.loads(model["selection_obj_list"])
+        for i, selection_model in enumerate(selection_obj_model_list):
+            if selection_model["method"] in ["manual"]:
+                selection_obj = ManualSelection(method=self._inner_method[i])
+            else:
+                selection_obj = StandardSelection(method=self._inner_method[i])
+            selection_obj.restore(selection_model)
+            selection_obj_list.append(selection_obj)
+        self._selection_obj = selection_obj_list
+
+    @classmethod
+    def from_model(cls, model) -> "HeteroSelectionModuleGuest":
+        selection_obj = HeteroSelectionModuleGuest(model["method"], model["select_col"])
+        selection_obj._inner_method = model["inner_method"]
+        selection_obj.restore(model)
+        return selection_obj
+
+
+class HeteroSelectionModuleHost(HeteroModule):
+    def __init__(self, method=None, select_col=None, isometric_model_dict=None,
+                 iv_param=None, statistic_param=None, manual_param=None,
+                 keep_one=True):
+        self.method = method
+        self.isometric_model_dict = isometric_model_dict
+        self.iv_param = iv_param
+        self.statistic_param = statistic_param
+        self.manual_param = manual_param
+        self.keep_one = keep_one
+        self.select_col = select_col
+        # for display of cascade order
+        self._inner_method = [None] * len(method)
+        self._selection_obj = [None] * len(method)
+
+    def fit(self, ctx: Context, train_data, validate_data=None) -> None:
+        if self.select_col is None:
+            self.select_col = train_data.schema.columns.to_list()
+        select_data = train_data[self.select_col]
+        header = select_data.schema.columns.to_list()
+        for i, filter_type in enumerate(self.method):
+            if filter_type == "manual":
+                selection_obj = ManualSelection(method=filter_type,
+                                                header=header,
+                                                param=self.manual_param,
+                                                keep_one=self.keep_one)
+                self._selection_obj[i] = selection_obj
+                self._inner_method[i] = "manual"
+            elif filter_type == "iv":
+                model = self.isometric_model_dict.get("binning", None)
+                if model is None:
+                    raise ValueError(f"Cannot find binning model in input, please check")
+                selection_obj = StandardSelection(method=filter_type,
+                                                  header=header,
+                                                  param=self.iv_param,
+                                                  model=model,
+                                                  keep_one=self.keep_one)
+                self._selection_obj[i] = selection_obj
+                self._inner_method[i] = "iv"
+            elif filter_type == "statistic":
+                model = self.isometric_model_dict.get("statistic", None)
+                if model is None:
+                    raise ValueError(f"Cannot find statistic model in input, please check")
+                selection_obj = StandardSelection(method=filter_type,
+                                                  header=header,
+                                                  param=self.statistic_param,
+                                                  model=model,
+                                                  keep_one=self.keep_one)
+                self._selection_obj[i] = selection_obj
+                self._inner_method[i] = "statistic"
+            else:
+                raise ValueError(f"{filter_type} selection method not supported, please check")
+
+        prev_selection_obj = None
+        for method, selection_obj in zip(self._inner_method, self._selection_obj):
+            if prev_selection_obj:
+                selection_obj.set_prev_selected_mask(copy.deepcopy(prev_selection_obj._selected_mask))
+            selection_obj.fit(ctx, train_data, validate_data)
+            if method == "iv":
+                if self.iv_param.select_federated:
+                    HeteroSelectionModuleHost.sync_select_federated(ctx, selection_obj, train_data)
+            prev_selection_obj = selection_obj
+
+    @staticmethod
+    def sync_select_federated(ctx: Context, selection_obj, data):
+        cur_selected_mask = ctx.guest.get(f"selected_mask_{selection_obj.method}")
+        columns, anonymous_columns = data.schema.columns, data.schema.anonymous_columns
+        new_index = [columns[anonymous_columns.index(col)] for col in cur_selected_mask.index]
+        cur_selected_mask.index = new_index
+        prev_selected_mask = selection_obj._prev_selected_mask[selection_obj._prev_selected_mask]
+        missing_col = set(prev_selected_mask.index).difference(set(new_index))
+        if missing_col:
+            raise ValueError(
+                f"results for columns: {missing_col} not found in received selection result.")
+        cur_selected_mask = [cur_selected_mask.get(col, False) for col in selection_obj._header]
+        selected_mask = selection_obj._prev_selected_mask & cur_selected_mask
+        selection_obj.set_selected_mask(selected_mask)
+
+    def transform(self, ctx: Context, test_data):
+        transformed_data = self._selection_obj[-1].transform(ctx, test_data)
+        return transformed_data
+
+    def to_model(self):
+        # all selection history need to be recorded for display
+        selection_obj_list = []
+        for selection_obj in self._selection_obj:
+            selection_obj_list.append(selection_obj.to_model())
+        return {"selection_obj_list": json.dumps(selection_obj_list),
+                "method": self.method,
+                "select_col": self.select_col,
+                "inner_method": self._inner_method}
+
+    def restore(self, model):
+        selection_obj_list = []
+        selection_obj_model_list = json.loads(model["selection_obj_list"])
+        for i, selection_model in enumerate(selection_obj_model_list):
+            if selection_model["method"] in ["manual"]:
+                selection_obj = ManualSelection(method=self._inner_method[i])
+            else:
+                selection_obj = StandardSelection(method=self._inner_method[i])
+            selection_obj.restore(selection_model)
+            selection_obj_list.append(selection_obj)
+        self._selection_obj = selection_obj_list
+
+    @classmethod
+    def from_model(cls, model) -> "HeteroSelectionModuleHost":
+        selection_obj = HeteroSelectionModuleHost(model["method"], model["select_col"])
+        selection_obj._inner_method = model["inner_method"]
+        selection_obj.restore(model)
+        return selection_obj
+
+
+class ManualSelection(Module):
+    def __init__(self, method, param=None, header=None, model=None, keep_one=True):
+        assert method == "manual", f"Manual Selection only accepts 'manual' as `method`, received {method} instead."
+        self.method = method
+        self.param = param
+        self.model = model
+        self.keep_one = keep_one
+        self._header = header
+        # mirror StandardSelection: a provided header seeds the *previous* mask,
+        # which `fit` then narrows into `_selected_mask`
+        self._selected_mask = None
+        if header is None:
+            self._prev_selected_mask = None
+        else:
+            self._prev_selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header)
+
+    def set_selected_mask(self, mask):
+        self._selected_mask = mask
+
+    def set_prev_selected_mask(self, mask):
+        self._prev_selected_mask = mask
+
+    def fit(self, ctx: Context, train_data, validate_data=None):
+        header = train_data.schema.columns.to_list()
+        if self._header is None:
+            self._header = header
+            self._prev_selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header)
+
+        filter_out_col = self.param.get("filter_out_col", None)
+        keep_col = self.param.get("keep_col", None)
+        if filter_out_col is None:
+            filter_out_col = []
+        if keep_col is None:
+            keep_col = []
+        if len(filter_out_col) >= len(header):
+            raise ValueError("`filter_out_col` should not be all columns")
+        filter_out_col = set(filter_out_col)
+        keep_col = set(keep_col)
+        missing_col = (filter_out_col.union(keep_col)). \
+            difference(set(self._prev_selected_mask.index))
+        if missing_col:
+            raise ValueError(f"columns {missing_col} given in `filter_out_col` & `keep_col` "
+                             f"not found in `select_col` or header")
+        filter_out_mask = pd.Series([False if col in filter_out_col else True for col in self._header],
+                                    index=self._header)
+        # keep_mask = [True if col in keep_col else False for col in self._header]
+        selected_mask = self._prev_selected_mask & filter_out_mask
+        # pandas `.loc` does not accept a set indexer, so materialize as a list
+        selected_mask.loc[list(keep_col)] = True
+        self._selected_mask = selected_mask
+        if self.keep_one:
+            StandardSelection._keep_one(self._selected_mask, select_col=self._header)
+
+    def transform(self, ctx: Context, transform_data):
+        logger.debug(f"Start transform")
+        drop_cols = set(self._selected_mask[~self._selected_mask].index)
+        select_cols = [col for col in transform_data.schema.columns.to_list() if col not in drop_cols]
+        return transform_data[select_cols]
+
+    def to_model(self):
+        return dict(
+            method=self.method,
+            keep_one=self.keep_one,
+            selected_mask=self._selected_mask.to_dict()
+        )
+
+    def restore(self, model):
+        self.method = model["method"]
+        self.keep_one = model["keep_one"]
+        self._selected_mask = pd.Series(model["selected_mask"], dtype=bool)
+
+
+class StandardSelection(Module):
+    def __init__(self, method, header=None, param=None, model=None, keep_one=True):
+        self.method = method
+        self.param = param
+        self.filter_conf = {}
+
+        if param is not None:
+            for metric_name, filter_type, threshold, take_high in zip(
+                    self.param.get("metrics", DEFAULT_METRIC.get(method)),
+                    self.param.get("filter_type", ['threshold']),
+                    self.param.get("threshold", [1.0]),
+                    self.param.get("take_high", [True])):
+                metric_conf = self.filter_conf.get(metric_name, {})
+                metric_conf["filter_type"] = metric_conf.get("filter_type", []) + [filter_type]
+                metric_conf["threshold"] = metric_conf.get("threshold", []) + [threshold]
+                metric_conf["take_high"] = metric_conf.get("take_high", []) + [take_high]
+                self.filter_conf[metric_name] = metric_conf
+        # temp code block starts
+        """if param is not None:
+            for metric_name, filter_type, threshold, take_high in zip(
+                    self.param.metrics or DEFAULT_METRIC.get(method),
+                    self.param.filter_type or ['threshold'],
+                    self.param.threshold or [1.0],
+                    self.param.take_high or [True]):
+                metric_conf = self.filter_conf.get(metric_name, {})
+                metric_conf["filter_type"] = metric_conf.get("filter_type", []) + [filter_type]
+                metric_conf["threshold"] = metric_conf.get("threshold", []) + [threshold]
+                metric_conf["take_high"] = metric_conf.get("take_high", []) + [take_high]
+                self.filter_conf[metric_name] = metric_conf"""
+        # temp code block ends
+
+        self.model = self.convert_model(model)
+        self.keep_one = keep_one
+        self._header = header
+        self._selected_mask = None
+        self._all_selected_mask = None
+        if header is None:
+            self._prev_selected_mask = None
+        else:
+            self._prev_selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header)
+        self._host_selected_mask = {}
+        self._all_host_selected_mask = {}
+        self._host_prev_selected_mask = {}
+        self._all_metrics = None
+        self._all_host_metrics = {}
+
+    @staticmethod
+    def convert_model(input_model):
+        return input_model
+
+    def set_host_prev_selected_mask(self, mask):
+        self._host_prev_selected_mask = mask
+
+    def set_prev_selected_mask(self, mask):
+        self._prev_selected_mask = mask
+
+    def fit(self, ctx: Context, train_data, validate_data=None):
+        if self._header is None:
+            header = train_data.schema.columns.to_list()
+            self._header = header
+            self._prev_selected_mask = 
pd.Series(np.ones(len(header)), dtype=bool, index=header) + + metric_names = self.param.get("metrics", []) + # temp code bock start + """metric_names = self.param.metrics or []""" + # temp code ends + # local only + if self.method in ["statistic"]: + for metric_name in metric_names: + if metric_name not in self.model.get("metrics", {}): + raise ValueError(f"metric {metric_name} not found in given statistic model with metrics: " + f"{metric_names}, please check") + + metrics_all = pd.DataFrame(self.model.get("metrics_summary", {})).loc[metric_names] + self._all_metrics = metrics_all + missing_col = set(self._prev_selected_mask[self._prev_selected_mask].index). \ + difference(set(metrics_all.columns)) + if missing_col: + raise ValueError( + f"metrics for columns {missing_col} from `select_col` or header not found in given model.") + + """ mask_all = metrics_all.apply(lambda r: StandardSelection.filter_multiple_metrics(r, + self.param.filter_type, + self.param.threshold, + self.param.take_high, + metric_names), axis=1)""" + mask_all = self.apply_filter(metrics_all, self.filter_conf) + self._all_selected_mask = mask_all + cur_selected_mask = mask_all.all(axis=0) + cur_selected_mask = [cur_selected_mask.get(col, False) for col in self._header] + self._selected_mask = self._prev_selected_mask & cur_selected_mask + if self.keep_one: + self._keep_one(self._selected_mask, self._prev_selected_mask, self._header) + # federated selection possible + elif self.method == "iv": + # host does not perform local iv selection + if ctx.local[0] == "host": + return + iv_metrics = pd.Series(self.model["metrics_summary"]["iv"]) + metrics_all = pd.DataFrame(iv_metrics).T.rename({0: "iv"}, axis=0) + self._all_metrics = metrics_all + # works for multiple iv filters + """mask_all = metrics_all.apply(lambda r: StandardSelection.filter_multiple_metrics(r, + self.param.filter_type, + self.param.threshold, + self.param.take_high, + metric_names), axis=1) + """ + mask_all = self.apply_filter(metrics_all, self.filter_conf) + self._all_selected_mask = mask_all + cur_selected_mask = mask_all.all(axis=0) + cur_selected_mask = [cur_selected_mask.get(col, False) for col in self._header] + self._selected_mask = self._prev_selected_mask & cur_selected_mask + if self.keep_one: + self._keep_one(self._selected_mask, self._prev_selected_mask, self._header) + if self.param.get("select_federated", True): + host_metrics_summary = self.model["host_train_metrics_summary"] + for host, host_metrics in host_metrics_summary.items(): + iv_metrics = pd.Series(host_metrics["iv"]) + metrics_all = pd.DataFrame(iv_metrics).T.rename({0: "iv"}, axis=0) + self._all_host_metrics[host] = metrics_all + """host_mask_all = metrics_all.apply(lambda r: + StandardSelection.filter_multiple_metrics(r, + self.param.host_filter_type, + self.param.threshold, + self.param.take_high, + metric_names), axis=1) + """ + host_mask_all = self.apply_filter(metrics_all, + self.filter_conf) + self._all_host_selected_mask[host] = host_mask_all + """host_prev_selected_mask = self._host_prev_selected_mask.get(host) + if host_prev_selected_mask is None: + host_prev_selected_mask = pd.Series(np.ones(len(iv_metrics.index)), + index=iv_metrics.index) + self._host_prev_selected_mask[host] = host_prev_selected_mask""" + + host_selected_mask = host_mask_all.all(axis=0) + if self.keep_one: + self._keep_one(host_selected_mask) + self._host_selected_mask[host] = host_selected_mask + + @staticmethod + def _keep_one(cur_mask, prev_mask=None, select_col=None): + if sum(cur_mask) > 0: + 
return cur_mask
+        else:
+            if prev_mask is not None:
+                idx = random.choice(prev_mask[prev_mask].index)
+            elif select_col is not None:
+                idx = random.choice(select_col)
+            else:
+                idx = random.choice(cur_mask.index)
+            cur_mask[idx] = True
+
+    @staticmethod
+    def convert_series_metric_to_dataframe(metrics, metric_name):
+        return pd.DataFrame(metrics).T.rename({0: metric_name}, axis=0)
+
+    @staticmethod
+    def apply_filter(metrics_all, filter_conf):
+        return metrics_all.apply(lambda r:
+                                 StandardSelection.filter_multiple_metrics(r,
+                                                                           filter_conf[r.name]),
+                                 axis=1)
+
+    @staticmethod
+    def filter_multiple_metrics(metrics, metric_conf):
+        filter_type_list = metric_conf["filter_type"]
+        threshold_list = metric_conf["threshold"]
+        take_high_list = metric_conf["take_high"]
+        result = pd.Series(np.ones(len(metrics.index)), index=metrics.index, dtype=bool)
+        for idx in range(len(filter_type_list)):
+            result &= StandardSelection.filter_metrics(metrics,
+                                                       filter_type_list[idx],
+                                                       threshold_list[idx],
+                                                       take_high_list[idx])
+        return result
+
+    @staticmethod
+    def filter_metrics(metrics, filter_type, threshold, take_high=True):
+        if filter_type == "top_k":
+            return StandardSelection.filter_by_top_k(metrics, threshold, take_high)
+        elif filter_type == "threshold":
+            return StandardSelection.filter_by_threshold(metrics, threshold, take_high)
+        elif filter_type == "top_percentile":
+            return StandardSelection.filter_by_percentile(metrics, threshold, take_high)
+        else:
+            raise ValueError(f"filter_type {filter_type} not supported, please check")
+
+    @staticmethod
+    def filter_by_top_k(metrics, k, take_high=True):
+        # strict top k; k == 0 means no top-k restriction
+        if k == 0:
+            return pd.Series(np.ones(len(metrics)), dtype=bool, index=metrics.index)
+        # stable sort
+        ordered_metrics = metrics.sort_values(ascending=not take_high, kind="mergesort")
+        select_k = ordered_metrics.index[:k]
+        return metrics.index.isin(select_k)
+
+    @staticmethod
+    def filter_by_threshold(metrics, threshold, take_high=True):
+        if take_high:
+            return metrics >= threshold
+        else:
+            return metrics <= threshold
+
+    @staticmethod
+    def filter_by_percentile(metrics, percentile, take_high=True):
+        if take_high:
+            return metrics >= metrics.quantile(percentile)
+        else:
+            return metrics <= metrics.quantile(1 - percentile)
+
+    def transform(self, ctx: Context, transform_data):
+        logger.debug(f"Start transform")
+        drop_cols = set(self._selected_mask[~self._selected_mask].index)
+        cols = transform_data.schema.columns.to_list()
+        select_cols = [col for col in cols if col not in drop_cols]
+        return transform_data[select_cols]
+
+    def to_model(self):
+        return dict(
+            method=self.method,
+            keep_one=self.keep_one,
+            all_selected_mask=self._all_selected_mask.to_dict(),
+            all_metrics=self._all_metrics.to_dict(),
+            all_host_metrics={k: v.to_dict() for k, v in self._all_host_metrics.items()},
+            selected_mask=self._selected_mask.to_dict(),
+            host_selected_mask={k: v.to_dict() for k, v in self._host_selected_mask.items()},
+            all_host_selected_mask={k: v.to_dict() for k, v in self._all_host_selected_mask.items()},
+        )
+
+    def restore(self, model):
+        self.method = model["method"]
+        self.keep_one = model["keep_one"]
+        self._selected_mask = pd.Series(model["selected_mask"], dtype=bool)
+        self._all_selected_mask = pd.DataFrame(model["all_selected_mask"], dtype=bool)
+        self._all_metrics = pd.DataFrame(model["all_metrics"])
+        self._host_selected_mask = {k: pd.Series(v, dtype=bool) for k, v in model["host_selected_mask"].items()}
+        self._all_host_selected_mask = {k: pd.DataFrame(v, dtype=bool) for
+                                        k, v in 
model["all_host_selected_mask"].items()} + self._all_host_metrics = {k: pd.DataFrame(v) for k, v in model["all_host_metrics"].items()} From 8eb58943042255f0e03980abda425dab7f9c660d Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Fri, 16 Jun 2023 17:46:50 +0800 Subject: [PATCH 02/61] fix hetero feature selection cpn import (#4661) add hetero feature selection example Signed-off-by: Yu Wu --- examples/pipeline/test_selection.py | 178 ++++++++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 examples/pipeline/test_selection.py diff --git a/examples/pipeline/test_selection.py b/examples/pipeline/test_selection.py new file mode 100644 index 0000000000..e6044d53fc --- /dev/null +++ b/examples/pipeline/test_selection.py @@ -0,0 +1,178 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json + +from fate_client.pipeline import StandalonePipeline, FateFlowPipeline +from fate_client.pipeline.components.fate import FeatureScale +from fate_client.pipeline.components.fate import Intersection +from fate_client.pipeline.components.fate import Reader +from fate_client.pipeline.components.fate import Statistics, HeteroFeatureSelection +from fate_client.pipeline.utils import test_utils + + +def main(config="./config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if config.work_mode == 0: + pipeline = StandalonePipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + else: + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + reader_0 = Reader(name="reader_0") + cluster = config.work_mode + + if cluster: + reader_0.guest.component_param(table_name="breast_hetero_guest", + namespace=f"{namespace}experiment", + # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", + # format="csv", + # match_id_name="id", + # delimiter=",", + label_name="y", + label_type="float32", + dtype="float32") + + reader_0.hosts[0].component_param(table_name="breast_hetero_host", + namespace=f"{namespace}experiment", + # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", + # match_id_name="id", + # delimiter=",", + label_name=None, + dtype="float32") + else: + data_base = config.data_base_dir + + reader_0.guest.component_param(path=f"file://{data_base}/examples/data/breast_hetero_guest.csv", + # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", + format="csv", + match_id_name="id", + delimiter=",", + label_name="y", + label_type="float32", + dtype="float32") + + reader_0.hosts[0].component_param(path=f"file://{data_base}/examples/data/breast_hetero_host.csv", + # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", + format="csv", + match_id_name="id", + delimiter=",", + label_name=None, + dtype="float32") + + intersection_0 = 
Intersection(name="intersection_0", + method="raw", + input_data=reader_0.outputs["output_data"]) + + intersection_1 = Intersection(name="intersection_1", + method="raw", + input_data=reader_0.outputs["output_data"]) + + feature_scale_0 = FeatureScale(name="feature_scale_0", + method="standard", + train_data=intersection_0.outputs["output_data"]) + + feature_scale_1 = FeatureScale(name="feature_scale_1", + test_data=intersection_1.outputs["output_data"], + input_model=feature_scale_0.outputs["output_model"]) + + statistics_0 = Statistics(name="statistics_0", train_data=feature_scale_1.outputs["test_output_data"], + metrics=["mean", "max", "std", "var", "kurtosis", "skewness"]) + + selection_0 = HeteroFeatureSelection(name="selection_0", train_data=intersection_0.outputs["output_data"], + method=["statistic"], + input_statistic_model=statistics_0.outputs["output_model"], + statistic_param={"metrics": ["mean", "max", "kurtosis", "skewness"]}) + + pipeline.add_task(reader_0) + pipeline.add_task(feature_scale_0) + pipeline.add_task(feature_scale_1) + pipeline.add_task(intersection_0) + pipeline.add_task(intersection_1) + pipeline.add_task(statistics_0) + pipeline.add_task(selection_0) + pipeline.compile() + print(pipeline.get_dag()) + pipeline.fit() + print(json.dumps(pipeline.get_task_info("statistics_0").get_output_model(), indent=4)) + + print(json.dumps(pipeline.get_task_info("selection_0").get_output_model(), indent=4)) + + predict_pipeline = StandalonePipeline() + reader_1 = Reader(name="reader_1") + if cluster: + reader_1.guest.component_param(table_name="breast_hetero_guest", + namespace=f"{namespace}experiment", + # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", + # format="csv", + # match_id_name="id", + # delimiter=",", + label_name="y", + label_type="float32", + dtype="float32") + + reader_1.hosts[0].component_param(table_name="breast_hetero_host", + namespace=f"{namespace}experiment", + # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", + # match_id_name="id", + # delimiter=",", + label_name=None, + dtype="float32") + else: + data_base = config.data_base_dir + + reader_1.guest.component_param(path=f"file://{data_base}/examples/data/breast_hetero_guest.csv", + # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", + format="csv", + match_id_name="id", + delimiter=",", + label_name="y", + label_type="float32", + dtype="float32") + + reader_1.hosts[0].component_param(path=f"file://{data_base}/examples/data/breast_hetero_host.csv", + # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", + format="csv", + match_id_name="id", + delimiter=",", + label_name=None, + dtype="float32") + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.intersection_0.input_data = reader_1.outputs["output_data"] + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_task(reader_1) + + print("\n\n\n") + print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + print(predict_pipeline.get_task_info("selection_0").get_output_data()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("-config", type=str, default="", + help="config file") + parser.add_argument("-namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) From 77ddaa6652994cdc40e8a6bf6173061ef0b4f792 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 4 
Jul 2023 19:11:33 +0800 Subject: [PATCH 03/61] edit selection(#4661) Signed-off-by: Yu Wu --- examples/pipeline/test_upload.py | 65 ++++++++++++------- python/fate/components/components/__init__.py | 6 ++ .../{ => core}/params/_filter_param.py | 0 3 files changed, 48 insertions(+), 23 deletions(-) rename python/fate/components/{ => core}/params/_filter_param.py (100%) diff --git a/examples/pipeline/test_upload.py b/examples/pipeline/test_upload.py index 6f72f0afba..c44261de8a 100644 --- a/examples/pipeline/test_upload.py +++ b/examples/pipeline/test_upload.py @@ -14,27 +14,46 @@ # limitations under the License. from fate_client.pipeline import FateFlowPipeline -pipeline = FateFlowPipeline() -pipeline.upload(file="${abs_path_of_data_guest}", - # file="/data/projects/fate/examples/data/breast_hetero_guest.csv", - head=1, - partitions=4, - namespace="experiment", - name="breast_hetero_guest", - meta={ - "label_name": "y", - "label_type": "float32", - "dtype": "float32" - }) +pipeline = FateFlowPipeline().set_roles( + local="0") +pipeline.set_site_role("local") +pipeline.set_site_party_id("0") +meta = {'delimiter': ',', + 'dtype': 'float32', + 'input_format': 'dense', + 'label_type': 'int32', + 'label_name': 'y', + 'match_id_name': 'id', + 'match_id_range': 0, + 'sample_id_name': 'id', + 'tag_value_delimiter': ':', + 'tag_with_value': False, + 'weight_type': 'float32'} -pipeline = FateFlowPipeline() -pipeline.upload(file="${abs_path_of_data_host}", - # file="/data/projects/fate/examples/data/breast_hetero_host.csv", - head=1, - partitions=4, - namespace="experiment", - name="breast_hetero_host", - meta={ - "label_name": None, - "dtype": "float32" - }) \ No newline at end of file +pipeline.transform_local_file_to_dataframe( # file="${abs_path_of_data_guest}", + file="/Users/yuwu/PycharmProjects/FATE/examples/data/breast_hetero_guest.csv", + meta=meta, head=True, + namespace="experiment", + name="breast_hetero_guest") + +meta = {'delimiter': ',', + 'dtype': 'float32', + 'input_format': 'dense', + 'label_type': 'int', + 'match_id_name': 'id', + 'match_id_range': 0, + 'sample_id_name': 'id', + 'tag_value_delimiter': ':', + 'tag_with_value': False, + 'weight_type': 'float32'} + +pipeline = FateFlowPipeline().set_roles( + local="0") +pipeline.set_site_role("local") +pipeline.set_site_party_id("0") + +pipeline.transform_local_file_to_dataframe( # file="${abs_path_of_data_host}", + file="/Users/yuwu/PycharmProjects/FATE/examples/data/breast_hetero_host.csv", + meta=meta, head=True, + namespace="experiment", + name="breast_hetero_host") diff --git a/python/fate/components/components/__init__.py b/python/fate/components/components/__init__.py index 67b05219f4..a1987e58d7 100644 --- a/python/fate/components/components/__init__.py +++ b/python/fate/components/components/__init__.py @@ -92,6 +92,12 @@ def statistics(self): return statistics + @_lazy_cpn + def hetero_feature_selection(self): + from .hetero_feature_selection import hetero_feature_selection + + return hetero_feature_selection + @_lazy_cpn def dataframe_io_test(self): from .dataframe_io_test import dataframe_io_test diff --git a/python/fate/components/params/_filter_param.py b/python/fate/components/core/params/_filter_param.py similarity index 100% rename from python/fate/components/params/_filter_param.py rename to python/fate/components/core/params/_filter_param.py From 37d94dc38a94f991e66dd1c630939f86f8dd0ccd Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 4 Jul 2023 19:11:50 +0800 Subject: [PATCH 04/61] edit selection(#4661) 
Signed-off-by: Yu Wu --- .../components/hetero_feature_selection.py | 246 +++++++----------- 1 file changed, 98 insertions(+), 148 deletions(-) diff --git a/python/fate/components/components/hetero_feature_selection.py b/python/fate/components/components/hetero_feature_selection.py index 789df30c31..bfb169331e 100644 --- a/python/fate/components/components/hetero_feature_selection.py +++ b/python/fate/components/components/hetero_feature_selection.py @@ -13,20 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from typing import List -from fate.components import ( - GUEST, - HOST, - DatasetArtifact, - Input, - ModelArtifact, - Output, - Role, - cpn, - params -) +from fate.arch import Context +from fate.components.core import GUEST, HOST, Role, cpn, params @cpn.component(roles=[GUEST, HOST]) @@ -35,149 +25,109 @@ def hetero_feature_selection(ctx, role): @hetero_feature_selection.train() -@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST]) -@cpn.artifact("input_statistic_model", type=Input[ModelArtifact], roles=[GUEST, HOST], optional=True) -@cpn.artifact("input_binning_model", type=Input[ModelArtifact], roles=[GUEST, HOST], optional=True) -@cpn.parameter("method", type=List[params.string_choice(["manual", "binning", "statistic"])], - default=["manual"], optional=False, - desc="selection method, options: {manual, binning, statistic}") -@cpn.parameter("select_col", type=List[str], default=None, - desc="list of column names to be selected, if None, all columns will be considered") -@cpn.parameter("iv_param", type=params.iv_filter_param(), - default=params.IVFilterParam(metrics="iv", take_high=True, - threshold=1, filter_type="threshold", host_thresholds=1, - host_take_high=True, - select_federated=True), - desc="binning filter param") -@cpn.parameter("statistic_param", type=params.statistic_filter_param(), - default=params.StatisticFilterParam(metrics="mean", - threshold=1, filter_type="threshold", take_high=True), - desc="statistic filter param") -@cpn.parameter("manual_param", type=params.manual_filter_param(), - default=params.ManualFilterParam(filter_out_col=[], keep_col=[]), - desc="note that manual filter will always be processed as the last filter") -@cpn.parameter("keep_one", type=bool, default=True, desc="whether to keep at least one feature among `select_col`") -@cpn.parameter("use_anonymous", type=bool, default=False, - desc="bool, whether interpret `select_col` & `filter_out_col` & `keep_col` as anonymous column names") -@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) -@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST]) -def feature_selection_train( - ctx, +def train( + ctx: Context, role: Role, - train_data, - input_statistic_model, - input_binning_model, - method, - select_col, - iv_param, - statistic_param, - manual_param, - keep_one, - use_anonymous, - train_output_data, - output_model, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + input_models: cpn.json_model_inputs(roles=[GUEST, HOST]), + method: cpn.parameter(type=List[params.string_choice(["manual", "binning", "statistic"])], + default=["manual"], optional=False, + desc="selection method, options: {manual, binning, statistic}"), + select_col: cpn.parameter(type=List[str], default=None, + desc="list of column names to be selected, if None, all columns will be considered"), + iv_param: cpn.parameter(type=params.iv_filter_param(), + 
default=params.IVFilterParam(metrics="iv", take_high=True,
+                                                         threshold=1, filter_type="threshold", host_threshold=1,
+                                                         host_take_high=True,
+                                                         select_federated=True),
+                            desc="iv filter param"),
+    statistic_param: cpn.parameter(type=params.statistic_filter_param(),
+                                   default=params.StatisticFilterParam(metrics="mean",
+                                                                       threshold=1,
+                                                                       filter_type="threshold",
+                                                                       take_high=True),
+                                   desc="statistic filter param"),
+    manual_param: cpn.parameter(type=params.manual_filter_param(),
+                                default=params.ManualFilterParam(filter_out_col=[], keep_col=[]),
+                                desc="note that manual filter will always be processed as the last filter"),
+    keep_one: cpn.parameter(type=bool,
+                            default=True,
+                            desc="whether to keep at least one feature among `select_col`"),
+    use_anonymous: cpn.parameter(type=bool, default=False,
+                                 desc="bool, whether interpret `select_col` & `filter_out_col` & `keep_col` "
+                                      "as anonymous column names"),
+    train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
+    train_output_model: cpn.json_model_output(roles=[GUEST, HOST])
 ):
-    train(ctx, role, train_data, train_output_data, input_binning_model, input_statistic_model,
-          output_model, method, select_col, iv_param, statistic_param, manual_param,
-          keep_one, use_anonymous)
+    from fate.ml.feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest
+
+    sub_ctx = ctx.sub_ctx("train")
+    isometric_model_dict = {}
+    for model in input_models:
+        model_type = model.artifact.metadata.metadata
+        model = model.read()
+        isometric_model_dict[model_type] = model
+
+    train_data = train_data.read()
+    columns = train_data.schema.columns.to_list()
+    if use_anonymous:
+        anonymous_columns = train_data.schema.anonymous_columns.to_list()
+        if select_col is not None:
+            select_col = [columns[anonymous_columns.index(col)] for col in select_col]
+        if manual_param.filter_out_col is not None:
+            filter_out_col = [columns[anonymous_columns.index(col)] for col in manual_param.filter_out_col]
+            manual_param.filter_out_col = filter_out_col
+        if manual_param.keep_col is not None:
+            keep_col = [columns[anonymous_columns.index(col)] for col in manual_param.keep_col]
+            manual_param.keep_col = keep_col
+
+    if role.is_guest:
+        selection = HeteroSelectionModuleGuest(method, select_col, isometric_model_dict,
+                                               iv_param, statistic_param, manual_param,
+                                               keep_one)
+    elif role.is_host:
+        selection = HeteroSelectionModuleHost(method, select_col, isometric_model_dict,
+                                              iv_param, statistic_param, manual_param,
+                                              keep_one)
+    else:
+        raise ValueError(f"role: {role} is not valid")
+    selection.fit(sub_ctx, train_data)
+    model = selection.to_model()
+    train_output_model.write(model, metadata={"method": method})
+
+    sub_ctx = ctx.sub_ctx("predict")
+    output_data = train_data
+    if method is not None:
+        output_data = selection.transform(sub_ctx, train_data)
+    train_output_data.write(output_data)
 
 
 @hetero_feature_selection.predict()
-@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST])
-@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST])
-@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
-def feature_selection_predict(
-    ctx,
+def predict(
+    ctx: Context,
     role: Role,
-    test_data,
-    input_model,
-    test_output_data,
+    test_data: cpn.dataframe_input(roles=[GUEST, HOST]),
+    input_model: cpn.json_model_input(roles=[GUEST, HOST]),
+    test_output_data: cpn.dataframe_output(roles=[GUEST, HOST])
 ):
-    predict(ctx, input_model, test_data, test_output_data, role)
-
-
-def train(ctx, role, 
train_data, train_output_data, input_binning_model, input_statistic_model, - output_model, method, select_col, iv_param, statistic_param, manual_param, - keep_one, use_anonymous): - from fate.ml.feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest - - with ctx.sub_ctx("train") as sub_ctx: - isometric_model_dict = {} - if input_binning_model: - with input_binning_model as model_reader: - model = model_reader.read_model() - model_type = json.loads(model["model_meta"]).get("model_type") - if model_type != "binning": - raise ValueError(f"model type: {model_type} is not binning, but {model_type}") - isometric_model_dict["binning"] = model - if input_statistic_model: - with input_statistic_model as model_reader: - model = model_reader.read_model() - # temp code block - model_type = json.loads(model["model_meta"]).get("model_type") - if model_type != "statistic": - raise ValueError(f"model type: {model_type} is not statistic, but {model_type}") - # temp code block end - isometric_model_dict["statistic"] = model - - # logger.info(f"input model: {isometric_model_dict}") - - train_data = sub_ctx.reader(train_data).read_dataframe().data - columns = train_data.schema.columns.to_list() - if use_anonymous: - anonymous_columns = train_data.schema.anonymous_columns.to_list() - if select_col is not None: - select_col = [columns[anonymous_columns.index(col)] for col in select_col] - if manual_param.filter_out_col is not None: - filter_out_col = [columns[anonymous_columns.index(col)] for col in manual_param.filter_out_col] - manual_param.filter_out_col = filter_out_col - if manual_param.keep_col is not None: - keep_col = [columns[anonymous_columns.index(col)] for col in manual_param.keep_col] - manual_param.keep_col = keep_col - - if role.is_guest: - selection = HeteroSelectionModuleGuest(method, select_col, isometric_model_dict, - iv_param, statistic_param, manual_param, - keep_one) - elif role.is_host: - selection = HeteroSelectionModuleHost(method, select_col, isometric_model_dict, - iv_param, statistic_param, manual_param, - keep_one) - selection.fit(sub_ctx, train_data) - model = selection.to_model() - with output_model as model_writer: - model_writer.write_model("feature_selection", model, metadata={"method": method}) - - with ctx.sub_ctx("predict") as sub_ctx: - output_data = train_data - if method is not None: - output_data = selection.transform(sub_ctx, train_data) - sub_ctx.writer(train_output_data).write_dataframe(output_data) - - -def predict(ctx, input_model, test_data, test_output_data, role): from fate.ml.feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest - with ctx.sub_ctx("predict") as sub_ctx: - with input_model as model_reader: - model = model_reader.read_model() - if role.is_guest: - selection = HeteroSelectionModuleGuest.from_model(model) - elif role.is_host: - selection = HeteroSelectionModuleHost.from_model(model) - - model_meta = model["meta_data"] - method = model_meta["method"] - selection.method = method - test_data = sub_ctx.reader(test_data).read_dataframe().data - - output_data = test_data - if method is not None: - output_data = selection.transform(sub_ctx, test_data) - """ - # temp code start - test_data = sub_ctx.reader(test_data).read_dataframe().data + sub_ctx = ctx.sub_ctx("predict") + with input_model as model_reader: + model = model_reader.read_model() + if role.is_guest: + selection = HeteroSelectionModuleGuest.from_model(model) + elif role.is_host: + selection = HeteroSelectionModuleHost.from_model(model) + else: + 
raise ValueError(f"role: {role} is not valid") + + model_meta = model["meta_data"] + method = model_meta["method"] + selection.method = method + test_data = test_data.read() + + output_data = test_data + if method is not None: output_data = selection.transform(sub_ctx, test_data) - # temp code end - """ - sub_ctx.writer(test_output_data).write_dataframe(output_data) + test_output_data.write(output_data) From f3df357e8bebe961c0fee4056c3815731e68b50c Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 4 Jul 2023 19:48:30 +0800 Subject: [PATCH 05/61] edit selection(#4661) Signed-off-by: Yu Wu --- examples/pipeline/test_selection.py | 217 +++++------------- .../components/core/params/_filter_param.py | 2 +- 2 files changed, 59 insertions(+), 160 deletions(-) diff --git a/examples/pipeline/test_selection.py b/examples/pipeline/test_selection.py index e6044d53fc..1cfa078484 100644 --- a/examples/pipeline/test_selection.py +++ b/examples/pipeline/test_selection.py @@ -12,167 +12,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse import json -from fate_client.pipeline import StandalonePipeline, FateFlowPipeline +from fate_client.pipeline import FateFlowPipeline from fate_client.pipeline.components.fate import FeatureScale from fate_client.pipeline.components.fate import Intersection -from fate_client.pipeline.components.fate import Reader from fate_client.pipeline.components.fate import Statistics, HeteroFeatureSelection -from fate_client.pipeline.utils import test_utils - - -def main(config="./config.yaml", namespace=""): - if isinstance(config, str): - config = test_utils.load_job_config(config) - - parties = config.parties - guest = parties.guest[0] - host = parties.host[0] - arbiter = parties.arbiter[0] - - if config.work_mode == 0: - pipeline = StandalonePipeline().set_roles(guest=guest, host=host, arbiter=arbiter) - else: - pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) - reader_0 = Reader(name="reader_0") - cluster = config.work_mode - - if cluster: - reader_0.guest.component_param(table_name="breast_hetero_guest", - namespace=f"{namespace}experiment", - # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", - # format="csv", - # match_id_name="id", - # delimiter=",", - label_name="y", - label_type="float32", - dtype="float32") - - reader_0.hosts[0].component_param(table_name="breast_hetero_host", - namespace=f"{namespace}experiment", - # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", - # match_id_name="id", - # delimiter=",", - label_name=None, - dtype="float32") - else: - data_base = config.data_base_dir - - reader_0.guest.component_param(path=f"file://{data_base}/examples/data/breast_hetero_guest.csv", - # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", - format="csv", - match_id_name="id", - delimiter=",", - label_name="y", - label_type="float32", - dtype="float32") - - reader_0.hosts[0].component_param(path=f"file://{data_base}/examples/data/breast_hetero_host.csv", - # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", - format="csv", - match_id_name="id", - delimiter=",", - label_name=None, - dtype="float32") - - intersection_0 = Intersection(name="intersection_0", - method="raw", - input_data=reader_0.outputs["output_data"]) - - intersection_1 = Intersection(name="intersection_1", - method="raw", - 
input_data=reader_0.outputs["output_data"]) - - feature_scale_0 = FeatureScale(name="feature_scale_0", - method="standard", - train_data=intersection_0.outputs["output_data"]) - - feature_scale_1 = FeatureScale(name="feature_scale_1", - test_data=intersection_1.outputs["output_data"], - input_model=feature_scale_0.outputs["output_model"]) - - statistics_0 = Statistics(name="statistics_0", train_data=feature_scale_1.outputs["test_output_data"], - metrics=["mean", "max", "std", "var", "kurtosis", "skewness"]) - - selection_0 = HeteroFeatureSelection(name="selection_0", train_data=intersection_0.outputs["output_data"], - method=["statistic"], - input_statistic_model=statistics_0.outputs["output_model"], - statistic_param={"metrics": ["mean", "max", "kurtosis", "skewness"]}) - - pipeline.add_task(reader_0) - pipeline.add_task(feature_scale_0) - pipeline.add_task(feature_scale_1) - pipeline.add_task(intersection_0) - pipeline.add_task(intersection_1) - pipeline.add_task(statistics_0) - pipeline.add_task(selection_0) - pipeline.compile() - print(pipeline.get_dag()) - pipeline.fit() - print(json.dumps(pipeline.get_task_info("statistics_0").get_output_model(), indent=4)) - - print(json.dumps(pipeline.get_task_info("selection_0").get_output_model(), indent=4)) - - predict_pipeline = StandalonePipeline() - reader_1 = Reader(name="reader_1") - if cluster: - reader_1.guest.component_param(table_name="breast_hetero_guest", - namespace=f"{namespace}experiment", - # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", - # format="csv", - # match_id_name="id", - # delimiter=",", - label_name="y", - label_type="float32", - dtype="float32") - - reader_1.hosts[0].component_param(table_name="breast_hetero_host", - namespace=f"{namespace}experiment", - # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", - # match_id_name="id", - # delimiter=",", - label_name=None, - dtype="float32") - else: - data_base = config.data_base_dir - - reader_1.guest.component_param(path=f"file://{data_base}/examples/data/breast_hetero_guest.csv", - # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", - format="csv", - match_id_name="id", - delimiter=",", - label_name="y", - label_type="float32", - dtype="float32") - - reader_1.hosts[0].component_param(path=f"file://{data_base}/examples/data/breast_hetero_host.csv", - # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", - format="csv", - match_id_name="id", - delimiter=",", - label_name=None, - dtype="float32") - - deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.intersection_0.input_data = reader_1.outputs["output_data"] - - predict_pipeline.add_task(deployed_pipeline) - predict_pipeline.add_task(reader_1) - - print("\n\n\n") - print(predict_pipeline.compile().get_dag()) - predict_pipeline.predict() - - print(predict_pipeline.get_task_info("selection_0").get_output_data()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("PIPELINE DEMO") - parser.add_argument("-config", type=str, default="", - help="config file") - parser.add_argument("-namespace", type=str, default="", - help="namespace for data stored in FATE") - args = parser.parse_args() - main(config=args.config, namespace=args.namespace) +from fate_client.pipeline.interface import DataWarehouseChannel + +pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") + +intersection_0 = Intersection("intersection_0", + method="raw") 
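+# each party binds its own uploaded table to a component input through a
+# DataWarehouseChannel addressed by (name, namespace); for example, a table
+# registered as "breast_hetero_guest" under the "experiment" namespace is
+# wired in roughly like this:
+#     channel = DataWarehouseChannel(name="breast_hetero_guest", namespace="experiment")
+#     intersection_0.guest.component_setting(input_data=channel)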
+intersection_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
+                                                                       namespace="experiment"))
+intersection_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
+                                                                          namespace="experiment"))
+
+intersection_1 = Intersection("intersection_1",
+                              method="raw")
+intersection_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
+                                                                       namespace="experiment"))
+intersection_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
+                                                                          namespace="experiment"))
+
+feature_scale_0 = FeatureScale("feature_scale_0",
+                               method="standard",
+                               train_data=intersection_0.outputs["output_data"])
+
+feature_scale_1 = FeatureScale("feature_scale_1",
+                               test_data=intersection_1.outputs["output_data"],
+                               input_model=feature_scale_0.outputs["output_model"])
+
+statistics_0 = Statistics("statistics_0", input_data=feature_scale_1.outputs["test_output_data"],
+                          metrics=["mean", "max", "std", "var", "kurtosis", "skewness"])
+
+selection_0 = HeteroFeatureSelection(name="selection_0", train_data=intersection_0.outputs["output_data"],
+                                     method=["statistic"],
+                                     input_statistic_model=statistics_0.outputs["output_model"],
+                                     statistic_param={"metrics": ["mean", "max", "kurtosis", "skewness"]})
+
+pipeline.add_task(intersection_0)
+pipeline.add_task(intersection_1)
+pipeline.add_task(feature_scale_0)
+pipeline.add_task(feature_scale_1)
+pipeline.add_task(statistics_0)
+pipeline.add_task(selection_0)
+pipeline.compile()
+print(pipeline.get_dag())
+pipeline.fit()
+print(json.dumps(pipeline.get_task_info("statistics_0").get_output_model(), indent=4))
+print(json.dumps(pipeline.get_task_info("selection_0").get_output_model(), indent=4))
+
+predict_pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998")
+pipeline.deploy([intersection_0, selection_0])
+deployed_pipeline = pipeline.get_deployed_pipeline()
+
+predict_pipeline.add_task(deployed_pipeline)
+
+print("\n\n\n")
+print(predict_pipeline.compile().get_dag())
+predict_pipeline.predict()
+
+print(predict_pipeline.get_task_info("selection_0").get_output_data())
diff --git a/python/fate/components/core/params/_filter_param.py b/python/fate/components/core/params/_filter_param.py
index 6ea875fd74..4b6435dd4d 100644
--- a/python/fate/components/core/params/_filter_param.py
+++ b/python/fate/components/core/params/_filter_param.py
@@ -71,7 +71,7 @@ class IVFilterParam(FederatedStandardFilterParam, Parameter):
 
 class StatisticFilterParam(StandardFilterParam, Parameter):
-    metrics: List[statistic_metrics_param(describe=False)] = ["mean"]
+    metrics: List[statistic_metrics_param()] = ["mean"]
 
 
 class ManualFilterParam(pydantic.BaseModel, Parameter):
From f463051534ab2d2d6244f6439b416a7e6ff6cf2c Mon Sep 17 00:00:00 2001
From: cwj
Date: Fri, 7 Jul 2023 11:05:18 +0800
Subject: [PATCH 06/61] Add Homo LR files

Signed-off-by: cwj
---
 python/fate/ml/glm/homo_lr/client.py | 54 ++++++++++++++++++++++++++++
 python/fate/ml/glm/homo_lr/server.py |  0
 2 files changed, 54 insertions(+)
 create mode 100644 python/fate/ml/glm/homo_lr/client.py
 create mode 100644 python/fate/ml/glm/homo_lr/server.py

diff --git a/python/fate/ml/glm/homo_lr/client.py b/python/fate/ml/glm/homo_lr/client.py
new file mode 100644
index 0000000000..6de8968128
--- /dev/null
+++ b/python/fate/ml/glm/homo_lr/client.py
@@ -0,0 +1,54 @@
+from typing import Optional
+from fate.arch import Context
+from fate.arch.dataframe import DataFrame
+from fate.ml.abc.module import HomoModule
+from 
fate.arch import Context +import logging +import pandas as pd + + +logger = logging.getLogger(__name__) + + +class Data(object): + + def __init__(self, features: pd.DataFrame, sample_ids: pd.DataFrame, match_ids: pd.DataFrame, labels: pd.DataFrame) -> None: + # set var + self.features = features + self.sample_ids = sample_ids + self.match_ids = match_ids + self.labels = labels + + @staticmethod + def from_fate_dataframe(df: DataFrame): + schema = df.schema + sample_id = schema.sample_id_name + match_id = schema.match_id_name + label = schema.label_name + pd_df = df.as_pd_df() + features = pd_df.drop([sample_id, match_id, label], axis=1) + sample_ids = pd_df[[sample_id]] + match_ids = pd_df[[match_id]] + labels = pd_df[[label]] + return Data(features, sample_ids, match_ids, labels) + + +class HomoLRClient(HomoModule): + + def __init__(self) -> None: + super().__init__() + self.df_schema = None + self.train_data = None + self.validate_data = None + self.predict_data = None + + def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = None) -> None: + + train_pd_df = DataFrame.as_pd_df() + if validate_data is not None: + validate_pd_df = DataFrame.as_pd_df() + + + + def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: + return super().predict(ctx, predict_data) \ No newline at end of file diff --git a/python/fate/ml/glm/homo_lr/server.py b/python/fate/ml/glm/homo_lr/server.py new file mode 100644 index 0000000000..e69de29bb2 From 319cfd64db642d9be34e8fb04154c8567098cec9 Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 10 Jul 2023 19:27:44 +0800 Subject: [PATCH 07/61] Add HomoLR ML parts Signed-off-by: cwj --- python/fate/ml/evaluation/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/fate/ml/evaluation/classification.py b/python/fate/ml/evaluation/classification.py index f998936e1e..c7ac33372b 100644 --- a/python/fate/ml/evaluation/classification.py +++ b/python/fate/ml/evaluation/classification.py @@ -43,7 +43,7 @@ class MultiAccuracy(Metric): def __call__(self, predict, label, **kwargs) -> Dict: predict = self.to_np_format(predict, flatten=False) - label = self.to_np_format(label) + label = self.to_np_format(label).astype(np.int32) if predict.shape != label.shape: predict = predict.argmax(axis=-1) acc = accuracy_score(label, predict) From bbb2d7de22decf5937960988979f7d3fbe40f9f1 Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 10 Jul 2023 19:27:57 +0800 Subject: [PATCH 08/61] Add HomoLR parts Signed-off-by: cwj --- fate_client | 2 +- fate_flow | 2 +- python/fate/ml/glm/homo_lr/client.py | 199 +++++++++++++++++- python/fate/ml/glm/homo_lr/server.py | 26 +++ python/fate/ml/glm/homo_lr/test/local_test.py | 27 +++ python/fate/ml/nn/trainer/trainer_base.py | 12 +- python/fate/ml/utils/model_io.py | 25 +++ 7 files changed, 281 insertions(+), 12 deletions(-) create mode 100644 python/fate/ml/glm/homo_lr/test/local_test.py create mode 100644 python/fate/ml/utils/model_io.py diff --git a/fate_client b/fate_client index fa2eeb98c6..7fcb28c933 160000 --- a/fate_client +++ b/fate_client @@ -1 +1 @@ -Subproject commit fa2eeb98c6a52abf7d6758b9643aa69903a7831c +Subproject commit 7fcb28c93331cd50e285e4e6c1bcd7ec8f8b896e diff --git a/fate_flow b/fate_flow index b8b0ecb82f..f067f8535d 160000 --- a/fate_flow +++ b/fate_flow @@ -1 +1 @@ -Subproject commit b8b0ecb82fa69a077abe7d23efd780eccc8abb92 +Subproject commit f067f8535d1ae463e7aca21e5c6a0ef4e0c8b6bf diff --git a/python/fate/ml/glm/homo_lr/client.py 
b/python/fate/ml/glm/homo_lr/client.py index 6de8968128..ad586123ea 100644 --- a/python/fate/ml/glm/homo_lr/client.py +++ b/python/fate/ml/glm/homo_lr/client.py @@ -1,10 +1,16 @@ -from typing import Optional +from typing import Optional, Union from fate.arch import Context from fate.arch.dataframe import DataFrame -from fate.ml.abc.module import HomoModule +from fate.ml.abc.module import HomoModule, Model, Module from fate.arch import Context import logging import pandas as pd +import torch as t +from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, FedAVGServer, TrainingArguments, FedAVGArguments +from torch.utils.data import TensorDataset +import numpy as np +from torch.nn import functional as F +import functools logger = logging.getLogger(__name__) @@ -33,22 +39,203 @@ def from_fate_dataframe(df: DataFrame): return Data(features, sample_ids, match_ids, labels) +class HomoLRModel(t.nn.Module): + + def __init__(self, feature_num, label_num=2) -> None: + super().__init__() + assert feature_num >= 2 and isinstance(feature_num, int), "feature_num must be int greater than 2" + assert label_num >= 1 and isinstance(label_num, int), "label_num must be int greater than 1" + self.models = t.nn.ModuleList() + + if label_num <= 2 and label_num > 0: + self.models.append( + t.nn.Linear(feature_num, 1) + ) + else: + # OVR Setting + for i in range(label_num): + self.models.append( + t.nn.Linear(feature_num, 1) + ) + self.sigmoid = t.nn.Sigmoid() + self.softmax = t.nn.Softmax(dim=1) + + def forward(self, x): + + if len(self.models) == 1: + linear_out = self.models[0](x) + else: + linear_out = t.cat([model(x) for model in self.models], dim=1) + + linear_out = self.sigmoid(linear_out).reshape((-1, len(self.models))) + + if not self.training: + prob = self.softmax(linear_out) + return prob + else: + return linear_out + + def to_dict(self): + model_dict = { + "feature_num": self.models[0].in_features, + "label_num": len(self.models), + "state_dict": {k: v.tolist() for k, v in self.state_dict().items()} # convert tensor to list + } + return model_dict + + @classmethod + def from_dict(cls, model_dict): + model = cls(model_dict["feature_num"], model_dict["label_num"]) + model_state_dict = {k: t.tensor(v) for k, v in model_dict["state_dict"].items()} # convert list back to tensor + model.load_state_dict(model_state_dict) + return model + + +def homo_lr_loss(pred, labels, dim=1): + """ + The function assumes that pred has shape (n, num_classes) where each class has its own linear model. + labels have shape (n,) and the values are integers denoting the class. 
+    """
+
+    # initialize the loss
+    loss = 0.0
+    if dim == 2:
+        dim -= 1
+
+    loss_fn = t.nn.BCELoss()
+
+    for c in range(dim):
+        # get binary labels for this class
+        binary_labels = (labels == c).float().flatten()
+        bin_pred = pred[:, c].flatten()
+        # compute binary cross-entropy loss and accumulate it across classes
+        loss += loss_fn(bin_pred, binary_labels)
+
+    # normalize loss by the number of classes
+    loss /= dim
+
+    return loss
+
+
+def optimizer_to_dict(optimizer):
+    # Convert the optimizer state to a dictionary that can be transformed to JSON
+    optimizer_dict = {
+        "state": {k: v.tolist() for k, v in optimizer.state_dict()['state'].items()},
+        "param_groups": optimizer.state_dict()['param_groups'],
+    }
+    return optimizer_dict
+
+
 class HomoLRClient(HomoModule):
 
-    def __init__(self) -> None:
+    def __init__(self, max_iter: int, batch_size: int, optimizer_param=None,
+                 learning_rate_param=None,
+                 init_param=None,
+                 threshold=0.5
+                 ) -> None:
+
         super().__init__()
         self.df_schema = None
         self.train_data = None
         self.validate_data = None
         self.predict_data = None
+        # set vars
+        self.max_iter = max_iter
+        self.batch_size = batch_size
+        self.optimizer_param = optimizer_param
+        self.learning_rate_param = learning_rate_param
+        self.init_param = init_param
+        self.threshold = threshold
+        self.run_ovr = False
+        self.train_feature_num = None
+        self.validate_feature_num = None
+
+        # models & optimizer & schduler
+        self.model = None
+        self.optimizer = None
+        self.scheduler = None
+
+        # checkping param
+        assert self.max_iter > 0 and isinstance(self.max_iter, int), "max_iter must be int greater than 0"
+        assert self.batch_size > 0 and isinstance(self.batch_size, int), "batch_size must be int greater than 0"
+        assert self.threshold > 0 and self.threshold < 1, "threshold must be float between 0 and 1"
+
+
+    def _make_dataset(self, data: Data):
+
+        X = np.array(data.features.values).astype(np.float32)
+        y = np.array(data.labels.values).astype(np.float32)
+        X_tensor = t.tensor(X, dtype=t.float32)
+        y_tensor = t.tensor(y.reshape((-1, 1)), dtype=t.float32)
+        dataset = TensorDataset(X_tensor, y_tensor)
+        return dataset
 
     def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = None) -> None:
 
-        train_pd_df = DataFrame.as_pd_df()
+        self.train_data = Data.from_fate_dataframe(train_data)
+        self.train_feature_num = self.train_data.features.values.shape[1]
+
         if validate_data is not None:
-            validate_pd_df = DataFrame.as_pd_df()
+            self.validate_data = Data.from_fate_dataframe(validate_data)
+            self.validate_feature_num = self.validate_data.features.values.shape[1]
+            assert self.train_feature_num == self.validate_feature_num, "train and validate feature num not match: {} vs {}".format(self.train_feature_num, self.validate_feature_num)
+
+        unique_label_set = set(self.train_data.labels.values.reshape(-1))
         if validate_data is not None:
+            unique_label_set = unique_label_set.union(set(self.validate_data.labels.values.reshape(-1)))
+            logger.info("unique label set updated to: {}".format(unique_label_set))
+
+        train_set = self._make_dataset(self.train_data)
+
+        if self.validate_data is not None:
+            validate_set = self._make_dataset(self.validate_data)
+        else:
+            validate_set = None
+
+        loss_fn = functools.partial(homo_lr_loss, dim=len(unique_label_set))
+
+        model = HomoLRModel(self.train_feature_num, label_num=len(unique_label_set))
+        self.model = model
+        logger.info('model structure is {}'.format(model))
+
+        optimizer = t.optim.SGD(model.parameters(), lr=self.learning_rate_param)
+        self.optimizer = optimizer
+        # training
+        fed_arg = FedAVGArguments()
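+        # FedAVGArguments carries the federation-side defaults of the FedAVG
+        # aggregation, while the HuggingFace-style TrainingArguments drives the
+        # local loop: max_iter maps onto num_train_epochs and batch_size onto
+        # the per-device batch sizes passed below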
+        train_arg = TrainingArguments(num_train_epochs=self.max_iter,
+                                      per_device_train_batch_size=self.batch_size, per_gpu_eval_batch_size=self.batch_size)
+        trainer = FedAVGCLient(ctx, model=model, loss_fn=loss_fn, optimizer=optimizer, train_set=train_set,
+                               val_set=validate_set, training_args=train_arg, fed_args=fed_arg)
+
+        # !!!!!!!!!!
+        # TODO
+        # !!!!!!!!!!
+        trainer.set_local_mode()
+        trainer.train()
 
-    def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame:
-        return super().predict(ctx, predict_data)
\ No newline at end of file
+    def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame:
+        return super().predict(ctx, predict_data)
+
+    def get_model(self) -> dict:
+        param = {}
+        if self.model is not None:
+            param['model'] = self.model.to_dict()
+        if self.optimizer is not None:
+            param['optimizer'] = optimizer_to_dict(self.optimizer)
+
+        meta = {'batch_size': self.batch_size, 'max_iter': self.max_iter, 'threshold': self.threshold,
+                'optimizer_param': self.optimizer_param, 'learning_rate_param': self.learning_rate_param, 'init_param': self.init_param}
+        ret = {'meta': meta, 'param': param}
+
+        return ret
+
+    @classmethod
+    def from_model(cls, model: dict) -> Module:
+        if not hasattr(model, 'model'):
+            raise ('key "param" is not found in the input model dict')
+        param = model['param']
+        if not hasattr(param, 'model'):
+            raise ValueError("param dict must have key 'model' that contains the model parameter and structure info")
+
+
+
diff --git a/python/fate/ml/glm/homo_lr/server.py b/python/fate/ml/glm/homo_lr/server.py
index e69de29bb2..5f70397491 100644
--- a/python/fate/ml/glm/homo_lr/server.py
+++ b/python/fate/ml/glm/homo_lr/server.py
@@ -0,0 +1,26 @@
+from fate.ml.abc.module import HomoModule
+from fate.arch import Context
+from fate.arch.dataframe import DataFrame
+from fate.arch import Context
+import logging
+from fate.ml.nn.algo.homo.fedavg import FedAVGServer
+
+
+logger = logging.getLogger(__name__)
+
+class HomoLRServer(HomoModule):
+
+    def __init__(self) -> None:
+        pass
+
+    def fit(self, ctx: Context, data: DataFrame) -> None:
+
+
+        server = FedAVGServer(ctx=ctx)
+        logger.info('server class init done, start fed training')
+        server.train()
+        logger.info('homo lr fit done')
+
+    def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame:
+
+        logger.info('skip prediction stage')
\ No newline at end of file
diff --git a/python/fate/ml/glm/homo_lr/test/local_test.py b/python/fate/ml/glm/homo_lr/test/local_test.py
new file mode 100644
index 0000000000..8bbf5d4fb1
--- /dev/null
+++ b/python/fate/ml/glm/homo_lr/test/local_test.py
@@ -0,0 +1,27 @@
+from fate.arch import Context
+from fate.arch.computing.standalone import CSession
+from fate.arch.context import Context
+from fate.arch.federation.standalone import StandaloneFederation
+import pandas as pd
+from fate.arch.dataframe import PandasReader
+from fate.ml.glm.homo_lr.client import HomoLRClient, HomoLRModel
+
+
+
+computing = CSession()
+ctx = Context(
+    "guest",
+    computing=computing,
+    federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]),
+)
+
+df = pd.read_csv('./examples/data/vehicle_scale_homo_guest.csv')
+df['sample_id'] = [i for i in range(len(df))]
+
+reader = PandasReader(sample_id_name='sample_id', match_id_name="id", label_name="y", dtype="object")
+data = reader.to_frame(ctx, df)
+df = data.as_pd_df()
+
+client = HomoLRClient(50, 32, learning_rate_param=0.01)
+a = client.fit(ctx, data)
+
diff --git a/python/fate/ml/nn/trainer/trainer_base.py b/python/fate/ml/nn/trainer/trainer_base.py
index 
8fefff2895..232ab35ecf 100644 --- a/python/fate/ml/nn/trainer/trainer_base.py +++ b/python/fate/ml/nn/trainer/trainer_base.py @@ -838,11 +838,15 @@ def compute_loss(self, model, inputs, **kwargs): if self._use_hf_default_behavior: return super().compute_loss(model, inputs, **kwargs) else: + # (features, labels), this format is used in FATE-1.x + if isinstance(inputs, tuple) or isinstance(inputs, list) and len(inputs) == 2: + feats, labels = inputs + output = model(feats) + loss = self.loss_func(output, labels) + return loss + else: + return super().compute_loss(model, inputs, **kwargs) - feats, labels = inputs - logits = model(feats) - loss = self.loss_func(logits, labels) - return loss def prediction_step(self, model: nn.Module, diff --git a/python/fate/ml/utils/model_io.py b/python/fate/ml/utils/model_io.py new file mode 100644 index 0000000000..6e5189f79d --- /dev/null +++ b/python/fate/ml/utils/model_io.py @@ -0,0 +1,25 @@ +from typing import Optional + + +class ModelExporter: + _META = "meta" + _DATA = "data" + + def __init__(self, data: dict, meta: Optional[dict] = None): + self.data = data + self.meta = meta + + def dict(self): + return { + self._DATA: self.data, + self._META: self.meta if self.meta is not None else {}, + } + + @classmethod + def from_dict(cls, d: dict): + data = d[cls._DATA] + if cls._META in d: + meta = d[cls._META] + else: + meta = None + return cls(data, meta) \ No newline at end of file From f0b2833c68fce1a297616bdc1048eeddcc925c03 Mon Sep 17 00:00:00 2001 From: cwj Date: Mon, 10 Jul 2023 19:31:44 +0800 Subject: [PATCH 09/61] Add Model Exporter Signed-off-by: cwj --- python/fate/ml/abc/module.py | 3 +-- python/fate/ml/glm/homo_lr/client.py | 7 ++++--- python/fate/ml/utils/model_io.py | 5 ++++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/fate/ml/abc/module.py b/python/fate/ml/abc/module.py index e295b7ed83..17f30a8236 100644 --- a/python/fate/ml/abc/module.py +++ b/python/fate/ml/abc/module.py @@ -39,8 +39,7 @@ def transform(self, ctx: Context, transform_data: DataFrame) -> DataFrame: def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: ... - @classmethod - def from_model(cls, model: Union[dict, Model]) -> "Module": + def from_model(cls, model: Union[dict, Model]): ... 
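+    # from_model is intentionally left without the @classmethod constraint:
+    # concrete modules such as HomoLRClient implement it as an instance method
+    # that loads parameters into an already-configured object instead of
+    # constructing a new instance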
def get_model(self) -> Union[dict, Model]: diff --git a/python/fate/ml/glm/homo_lr/client.py b/python/fate/ml/glm/homo_lr/client.py index ad586123ea..ec753a7523 100644 --- a/python/fate/ml/glm/homo_lr/client.py +++ b/python/fate/ml/glm/homo_lr/client.py @@ -1,7 +1,8 @@ from typing import Optional, Union from fate.arch import Context from fate.arch.dataframe import DataFrame -from fate.ml.abc.module import HomoModule, Model, Module +from fate.ml.abc.module import HomoModule, Module +from fate.ml.utils.model_io import ModelExporter from fate.arch import Context import logging import pandas as pd @@ -225,9 +226,9 @@ def get_model(self) -> dict: meta = {'batch_size': self.batch_size, 'max_iter': self.max_iter, 'threshold': self.threshold, 'optimizer_param': self.optimizer_param, 'learning_rate_param': self.learning_rate_param, 'init_param': self.init_param} - ret = {'meta': meta, 'param': param} + export_ = ModelExporter(data=param, meta=meta) - return ret + return export_ @classmethod def from_model(cls, model: dict) -> Module: diff --git a/python/fate/ml/utils/model_io.py b/python/fate/ml/utils/model_io.py index 6e5189f79d..898a07bffe 100644 --- a/python/fate/ml/utils/model_io.py +++ b/python/fate/ml/utils/model_io.py @@ -22,4 +22,7 @@ def from_dict(cls, d: dict): meta = d[cls._META] else: meta = None - return cls(data, meta) \ No newline at end of file + return cls(data, meta) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(data={self.data}, meta={self.meta})" \ No newline at end of file From 59189338d877c7500956fe0bacd2debfed4906cb Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 11 Jul 2023 16:26:51 +0800 Subject: [PATCH 10/61] get model type from meta in ml(#4661) Signed-off-by: Yu Wu --- .../components/hetero_feature_selection.py | 21 ++--- .../hetero_feature_selection.py | 85 ++++++++++++------- 2 files changed, 60 insertions(+), 46 deletions(-) diff --git a/python/fate/components/components/hetero_feature_selection.py b/python/fate/components/components/hetero_feature_selection.py index bfb169331e..6e94de4d3f 100644 --- a/python/fate/components/components/hetero_feature_selection.py +++ b/python/fate/components/components/hetero_feature_selection.py @@ -30,9 +30,9 @@ def train( role: Role, train_data: cpn.dataframe_input(roles=[GUEST, HOST]), input_models: cpn.json_model_inputs(roles=[GUEST, HOST]), - method: cpn.parameter(type=List[params.string_choice(["manual", "binning", "statistic"])], + method: cpn.parameter(type=List[params.string_choice(["manual", "binning", "statistics"])], default=["manual"], optional=False, - desc="selection method, options: {manual, binning, statistic}"), + desc="selection method, options: {manual, binning, statistics}"), select_col: cpn.parameter(type=List[str], default=None, desc="list of column names to be selected, if None, all columns will be considered"), iv_param: cpn.parameter(type=params.iv_filter_param(), @@ -49,7 +49,7 @@ def train( desc="statistic filter param"), manual_param: cpn.parameter(type=params.manual_filter_param(), default=params.ManualFilterParam(filter_out_col=[], keep_col=[]), - desc="note that manual filter will always be processed as the last filter"), + desc="manual filter param"), keep_one: cpn.parameter(type=bool, default=True, desc="whether to keep at least one feature among `select_col`"), @@ -62,11 +62,6 @@ def train( from fate.ml.feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest sub_ctx = ctx.sub_ctx("train") - isometric_model_dict = {} - for model in input_models: - 
model_type = model.artifact.metadata.metadata - model = model.read() - isometric_model_dict[model_type] = model train_data = train_data.read() columns = train_data.schema.columns.to_list() @@ -80,20 +75,20 @@ def train( if manual_param.keep_col is not None: keep_col = [columns[anonymous_columns.index(col)] for col in manual_param.keep_col] manual_param.keep_col = keep_col - + input_models = [model.read() for model in input_models] if role.is_guest: - selection = HeteroSelectionModuleGuest(method, select_col, isometric_model_dict, + selection = HeteroSelectionModuleGuest(method, select_col, input_models, iv_param, statistic_param, manual_param, keep_one) elif role.is_host: - selection = HeteroSelectionModuleHost(method, select_col, isometric_model_dict, + selection = HeteroSelectionModuleHost(method, select_col, input_models, iv_param, statistic_param, manual_param, keep_one) else: raise ValueError(f"role: {role} is not valid") selection.fit(sub_ctx, train_data) - model = selection.to_model() - train_output_model.write(model, metadata={"method": method}) + model = selection.get_model() + train_output_model.write(model, metadata={}) sub_ctx = ctx.sub_ctx("predict") output_data = train_data diff --git a/python/fate/ml/feature_selection/hetero_feature_selection.py b/python/fate/ml/feature_selection/hetero_feature_selection.py index e9a2f7feb4..8c9d64e015 100644 --- a/python/fate/ml/feature_selection/hetero_feature_selection.py +++ b/python/fate/ml/feature_selection/hetero_feature_selection.py @@ -21,7 +21,7 @@ import numpy as np import pandas as pd -from fate.interface import Context +from fate.arch import Context from ..abc.module import Module, HeteroModule logger = logging.getLogger(__name__) @@ -30,20 +30,28 @@ class HeteroSelectionModuleGuest(HeteroModule): - def __init__(self, method=None, select_col=None, isometric_model_dict=None, + def __init__(self, method=None, select_col=None, input_models=None, iv_param=None, statistic_param=None, manual_param=None, keep_one=True): self.method = method self.select_col = select_col - self.isometric_model_dict = isometric_model_dict self.iv_param = iv_param self.statistic_param = statistic_param self.manual_param = manual_param self.keep_one = keep_one + # keep selection history self._inner_method = [] self._selection_obj = [] + isometric_model_dict = {} + for model in input_models: + model_type = model["meta"].get("model_type") + if model_type is None: + raise ValueError(f"Missing 'model_type' in input model") + isometric_model_dict[model_type] = model + self.isometric_model_dict = isometric_model_dict + def fit(self, ctx: Context, train_data, validate_data=None) -> None: logger.info(f"isometric_model_dict: {self.isometric_model_dict}") if self.select_col is None: @@ -66,10 +74,10 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: param=self.iv_param, model=model, keep_one=self.keep_one) - elif filter_type == "statistic": - model = self.isometric_model_dict.get("statistic", None) + elif filter_type == "statistics": + model = self.isometric_model_dict.get("statistics", None) if model is None: - raise ValueError(f"Cannot find statistic model in input, please check") + raise ValueError(f"Cannot find statistics model in input, please check") selection_obj = StandardSelection(method=filter_type, header=header, param=self.statistic_param, @@ -103,15 +111,18 @@ def transform(self, ctx: Context, test_data): transformed_data = self._selection_obj[-1].transform(ctx, test_data) return transformed_data - def to_model(self): + def 
get_model(self):
         # all selection obj need to be recorded for display of cascade order
         selection_obj_list = []
         for selection_obj in self._selection_obj:
             selection_obj_list.append(selection_obj.to_model())
-        return {"selection_obj_list": json.dumps(selection_obj_list),
-                "method": self.method,
-                "select_col": self.select_col,
+        data = {"selection_obj_list": json.dumps(selection_obj_list),
                 "inner_method": self._inner_method}
+        meta = {"method": self.method,
+                "select_col": self.select_col,
+                "keep_one": self.keep_one}
+        return {"data": data,
+                "meta": meta}
 
     def restore(self, model):
         selection_obj_list = []
@@ -127,26 +138,33 @@ def restore(self, model):
 
     @classmethod
     def from_model(cls, model) -> "HeteroSelectionModuleGuest":
-        selection_obj = HeteroSelectionModuleGuest(model["method"], model["select_col"])
-        selection_obj._inner_method = model["inner_method"]
-        selection_obj.restore(model)
+        selection_obj = HeteroSelectionModuleGuest(model["meta"]["method"], model["meta"]["select_col"])
+        selection_obj._inner_method = model["data"]["inner_method"]
+        selection_obj.restore(model["data"])
         return selection_obj
 
 
 class HeteroSelectionModuleHost(HeteroModule):
 
-    def __init__(self, method=None, select_col=None, isometric_model_dict=None,
+    def __init__(self, method=None, select_col=None, input_models=None,
                  iv_param=None, statistic_param=None, manual_param=None,
                  keep_one=True):
         self.method = method
-        self.isometric_model_dict = isometric_model_dict
         self.iv_param = iv_param
         self.statistic_param = statistic_param
         self.manual_param = manual_param
         self.keep_one = keep_one
         self.select_col = select_col
-        # for display of cascade order
-        self._inner_method = [None] * len(method)
-        self._selection_obj = [None] * len(method)
+        # keep selection history
+        self._inner_method = []
+        self._selection_obj = []
+
+        isometric_model_dict = {}
+        for model in input_models:
+            model_type = model["meta"].get("model_type")
+            if model_type is None:
+                raise ValueError(f"Missing 'model_type' in input model")
+            isometric_model_dict[model_type] = model
+        self.isometric_model_dict = isometric_model_dict
 
     def fit(self, ctx: Context, train_data, validate_data=None) -> None:
         if self.select_col is None:
@@ -159,8 +177,6 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
                                               header=header,
                                               param=self.manual_param,
                                               keep_one=self.keep_one)
-                self._selection_obj[i] = selection_obj
-                self._inner_method[i] = "manual"
             elif filter_type == "iv":
                 model = self.isometric_model_dict.get("binning", None)
                 if model is None:
@@ -170,21 +186,20 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
                                               param=self.iv_param,
                                               model=model,
                                               keep_one=self.keep_one)
-                self._selection_obj[i] = selection_obj
-                self._inner_method[i] = "iv"
-            elif filter_type == "statistic":
-                model = self.isometric_model_dict.get("statistic", None)
+            elif filter_type == "statistics":
+                model = self.isometric_model_dict.get("statistics", None)
                 if model is None:
-                    raise ValueError(f"Cannot find statistic model in input, please check")
+                    raise ValueError(f"Cannot find statistics model in input, please check")
                 selection_obj = StandardSelection(method=filter_type,
                                                   header=header,
                                                   param=self.statistic_param,
                                                   model=model,
                                                   keep_one=self.keep_one)
-                self._selection_obj[i] = selection_obj
-                self._inner_method[i] = "statistic"
+
             else:
                 raise ValueError(f"{type} selection method not supported, please check")
+            self._selection_obj.append(selection_obj)
+            self._inner_method.append(filter_type)
 
         prev_selection_obj = None
         for method, selection_obj in zip(self._inner_method, self._selection_obj):
@@ -215,15 
+230,19 @@ def transform(self, ctx: Context, test_data): transformed_data = self._selection_obj[-1].transform(ctx, test_data) return transformed_data - def to_model(self): + def get_model(self): # all selection history need to be recorded for display selection_obj_list = [] for selection_obj in self._selection_obj: selection_obj_list.append(selection_obj.to_model()) - return {"selection_obj_list": json.dumps(selection_obj_list), - "method": self.method, - "select_col": self.select_col, + + data = {"selection_obj_list": json.dumps(selection_obj_list), "inner_method": self._inner_method} + meta = {"method": self.method, + "select_col": self.select_col, + "keep_one": self.keep_one} + return {"data": data, + "meta": meta} def restore(self, model): selection_obj_list = [] @@ -239,9 +258,9 @@ def restore(self, model): @classmethod def from_model(cls, model) -> "HeteroSelectionModuleHost": - selection_obj = HeteroSelectionModuleHost(model["method"], model["select_col"]) - selection_obj._inner_method = model["inner_method"] - selection_obj.restore(model) + selection_obj = HeteroSelectionModuleHost(model["meta"]["method"], model["meta"]["select_col"]) + selection_obj._inner_method = model["data"]["inner_method"] + selection_obj.restore(model["data"]) return selection_obj From 450fac9842688f3f339878a1d788ddfb6701d6e7 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 11 Jul 2023 17:08:02 +0800 Subject: [PATCH 11/61] 1. Update HomoLR 2. Refactor HomoNN parameter & add new callback 3. ModelO class Signed-off-by: cwj --- python/fate/ml/glm/homo_lr/client.py | 247 ++++++++++++------ python/fate/ml/glm/homo_lr/test/local_test.py | 28 +- python/fate/ml/nn/algo/homo/fedavg.py | 11 +- python/fate/ml/nn/trainer/trainer_base.py | 51 ++-- python/fate/ml/utils/model_io.py | 2 +- 5 files changed, 239 insertions(+), 100 deletions(-) diff --git a/python/fate/ml/glm/homo_lr/client.py b/python/fate/ml/glm/homo_lr/client.py index ec753a7523..9b3d3d33cf 100644 --- a/python/fate/ml/glm/homo_lr/client.py +++ b/python/fate/ml/glm/homo_lr/client.py @@ -1,17 +1,19 @@ -from typing import Optional, Union +import torch.nn as nn from fate.arch import Context from fate.arch.dataframe import DataFrame -from fate.ml.abc.module import HomoModule, Module -from fate.ml.utils.model_io import ModelExporter +from fate.ml.abc.module import HomoModule +from fate.ml.utils.model_io import ModelIO from fate.arch import Context import logging import pandas as pd import torch as t -from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, FedAVGServer, TrainingArguments, FedAVGArguments -from torch.utils.data import TensorDataset +from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, TrainingArguments, FedAVGArguments +from transformers import default_data_collator import numpy as np from torch.nn import functional as F import functools +import tempfile +from torch.utils.data import Dataset logger = logging.getLogger(__name__) @@ -40,15 +42,40 @@ def from_fate_dataframe(df: DataFrame): return Data(features, sample_ids, match_ids, labels) +def homo_lr_loss(pred, labels, dim=1): + """ + The function assumes that pred has shape (n, num_classes) where each class has its own linear model. + labels have shape (n,) and the values are integers denoting the class. 
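+
+    For example, with dim=3 the result is the average of the three
+    one-vs-rest binary cross-entropy terms:
+        loss = (BCE(pred[:, 0], labels == 0)
+                + BCE(pred[:, 1], labels == 1)
+                + BCE(pred[:, 2], labels == 2)) / 3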
+    """
+
+    # initialize the loss
+    loss = 0.0
+    if dim == 2:
+        dim -= 1
+
+    loss_fn = t.nn.BCELoss()
+
+    for c in range(dim):
+        # get binary labels for this class
+        binary_labels = (labels == c).float().flatten()
+        bin_pred = pred[:, c].flatten()
+        # compute binary cross-entropy loss and accumulate it across classes
+        loss += loss_fn(bin_pred, binary_labels)
+
+    # normalize loss by the number of classes
+    loss /= dim
+
+    return loss
+
+
 class HomoLRModel(t.nn.Module):
 
-    def __init__(self, feature_num, label_num=2) -> None:
+    def __init__(self, feature_num, label_num=2, l1=0) -> None:
         super().__init__()
         assert feature_num >= 2 and isinstance(feature_num, int), "feature_num must be int greater than 2"
         assert label_num >= 1 and isinstance(label_num, int), "label_num must be int greater than 1"
         self.models = t.nn.ModuleList()
 
-        if label_num <= 2 and label_num > 0:
+        if 2 >= label_num > 0:
             self.models.append(
                 t.nn.Linear(feature_num, 1)
             )
@@ -60,21 +87,34 @@ def __init__(self, feature_num, label_num=2) -> None:
             )
         self.sigmoid = t.nn.Sigmoid()
         self.softmax = t.nn.Softmax(dim=1)
+        self.l1 = l1
 
-    def forward(self, x):
+    def forward(self, x, labels=None):
 
         if len(self.models) == 1:
             linear_out = self.models[0](x)
         else:
             linear_out = t.cat([model(x) for model in self.models], dim=1)
 
+        ret_dict = {}
         linear_out = self.sigmoid(linear_out).reshape((-1, len(self.models)))
 
         if not self.training:
-            prob = self.softmax(linear_out)
-            return prob
-        else:
-            return linear_out
+            if len(self.models) > 1:
+                linear_out = self.softmax(linear_out)
+
+        ret_dict['pred'] = linear_out
+
+        if labels is not None:
+            loss = homo_lr_loss(linear_out, labels, dim=len(self.models))
+            if self.l1 != 0:
+                l1_regularization = t.tensor(0.)
+                for param in self.models.parameters():
+                    l1_regularization += t.norm(param, 1)
+                loss += self.l1 * l1_regularization
+            ret_dict['loss'] = loss
+
+        return ret_dict
 
     def to_dict(self):
         model_dict = {
@@ -92,46 +132,68 @@ def from_dict(cls, model_dict):
         return model
 
 
-def homo_lr_loss(pred, labels, dim=1):
-    """
-    The function assumes that pred has shape (n, num_classes) where each class has its own linear model.
-    labels have shape (n,) and the values are integers denoting the class.
-    """
+def init_model(model, method='random', val=1.0):
+    if method == 'zeros':
+        init_fn = nn.init.zeros_
+    elif method == 'ones':
+        init_fn = nn.init.ones_
+    elif method == 'consts':
+        init_fn = lambda x: nn.init.constant_(x, val)
+    elif method == 'random':
+        init_fn = nn.init.normal_
+    else:
+        raise ValueError("Invalid method. 
Options are: 'zeros', 'ones', 'consts', 'random'") - # initialize the loss - loss = 0.0 - if dim == 2: - dim -= 1 + for name, param in model.named_parameters(): + if 'bias' in name: + nn.init.zeros_(param) # usually it's good practice to initialize biases to zero + else: + init_fn(param) - loss_fn = t.nn.BCELoss() - for c in range(dim): - # get binary labels for this class - binary_labels = (labels == c).float().flatten() - bin_pred = pred[:, c].flatten() - # compute binary cross-entropy loss - loss = loss_fn(bin_pred, binary_labels) - # normalize loss by the number of classes - loss /= dim +# read model from model bytes +def recover_torch_bytes(model_bytes): - return loss + with tempfile.TemporaryFile() as f: + f.write(model_bytes) + f.seek(0) + model_dict = t.load(f) + + return model_dict + + +def get_torch_bytes(model_dict): + + with tempfile.TemporaryFile() as f: + t.save(model_dict, f) + f.seek(0) + model_saved_bytes = f.read() + + return model_saved_bytes + + +class DictDataset(Dataset): + """TensorDataset with support of transforms. + """ + def __init__(self, data): + self.X = np.array(data.features.values).astype(np.float32) + self.y = np.array(data.labels.values).astype(np.float32) + self.X_tensor = t.tensor(self.X, dtype=t.float32) + self.y_tensor = t.tensor(self.y.reshape((-1, 1)), dtype=t.float32) + def __getitem__(self, index): + return {'x': self.X_tensor[index], 'label': self.y_tensor[index]} -def optimizer_to_dict(optimizer): - # Convert the optimizer state to a dictionary that can be transformed to JSON - optimizer_dict = { - "state": {k: v.tolist() for k, v in optimizer.state_dict()['state'].items()}, - "param_groups": optimizer.state_dict()['param_groups'], - } - return optimizer_dict + def __len__(self): + return self.X_tensor.shape[0] class HomoLRClient(HomoModule): - def __init__(self, max_iter: int, batch_size: int, optimizer_param=None, + def __init__(self, max_iter: int=5, batch_size: int=32, optimizer_param=None, learning_rate_param=None, init_param=None, - threshold=0.5 + threshold: float=0.5 ) -> None: super().__init__() @@ -155,21 +217,23 @@ def __init__(self, max_iter: int, batch_size: int, optimizer_param=None, self.model = None self.optimizer = None self.scheduler = None + self.optimizer_state_dict = None + self.trainer = None + + # loaded meta + self.loaded_meta = None + + # l1 & l2 + self.l1 = 0 + self.l2 = 0 # checkping param assert self.max_iter > 0 and isinstance(self.max_iter, int), "max_iter must be int greater than 0" assert self.batch_size > 0 and isinstance(self.batch_size, int), "batch_size must be int greater than 0" assert self.threshold > 0 and self.threshold < 1, "threshold must be float between 0 and 1" - def _make_dataset(self, data: Data): - - X = np.array(data.features.values).astype(np.float32) - y = np.array(data.labels.values).astype(np.float32) - X_tensor = t.tensor(X, dtype=t.float32) - y_tensor = t.tensor(y.reshape((-1, 1)), dtype=t.float32) - dataset = TensorDataset(X_tensor, y_tensor) - return dataset + return DictDataset(data) def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = None) -> None: @@ -192,51 +256,84 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No else: validate_set = None + # prepare loss function loss_fn = functools.partial(homo_lr_loss, dim=len(unique_label_set)) - model = HomoLRModel(self.train_feature_num, label_num=len(unique_label_set)) - self.model = model - logger.info('model structure is {}'.format(model)) + # initialize model + if self.model is 
None:
+
+            self.model = HomoLRModel(self.train_feature_num, label_num=len(unique_label_set), l1=self.l1)
-
+            # init model here
+            init_model(self.model)
+
+            logger.info('model initialized')
+            logger.info('model parameters are {}'.format(list(self.model.parameters())))
+        else:
+            logger.info('model is loaded')
+            logger.info('model structure is {}'.format(self.model))
+
+        # initialize optimizer
+        self.optimizer = t.optim.SGD(self.model.parameters(), lr=self.learning_rate_param, weight_decay=self.l2)
+        if self.optimizer_state_dict is not None:
+            optimizer_state_dict = {
+                "state": {k: t.tensor(v) for k, v in self.optimizer_state_dict['state'].items()},
+                "param_groups": self.optimizer_state_dict['param_groups'],
+            }
+            self.optimizer.load_state_dict(optimizer_state_dict)
+            logger.info('load warmstart optimizer state dict')
 
         # training
         fed_arg = FedAVGArguments()
         train_arg = TrainingArguments(num_train_epochs=self.max_iter,
                                       per_device_train_batch_size=self.batch_size, per_gpu_eval_batch_size=self.batch_size)
-        trainer = FedAVGCLient(ctx, model=model, loss_fn=loss_fn, optimizer=optimizer, train_set=train_set,
-                               val_set=validate_set, training_args=train_arg, fed_args=fed_arg)
+        self.trainer = FedAVGCLient(ctx, model=self.model, loss_fn=loss_fn, optimizer=self.optimizer, train_set=train_set,
+                                    val_set=validate_set, training_args=train_arg, fed_args=fed_arg, data_collator=default_data_collator)
+        self.trainer.set_local_mode()
+        self.trainer.train()
-
-        # !!!!!!!!!!
-        # TODO
-        # !!!!!!!!!!
-        trainer.set_local_mode()
-        trainer.train()
 
     def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame:
-        return super().predict(ctx, predict_data)
+
+        if self.model is None:
+            raise ValueError("model is not initialized")
+        self.predict_data = Data.from_fate_dataframe(predict_data)
+        predict_set = self._make_dataset(self.predict_data)
+        if self.trainer is None:
+            train_arg = TrainingArguments(num_train_epochs=self.max_iter, per_device_eval_batch_size=self.batch_size)
+            trainer = FedAVGCLient(ctx, train_set=predict_set, model=self.model, training_args=train_arg,
+                                   fed_args=FedAVGArguments(), data_collator=default_data_collator)
+            trainer.set_local_mode()
+        else:
+            trainer = self.trainer
+        predict_rs = trainer.predict(predict_set)
+
+        return predict_rs
+
-    def get_model(self) -> dict:
+    def get_model(self) -> ModelIO:
         param = {}
         if self.model is not None:
             param['model'] = self.model.to_dict()
         if self.optimizer is not None:
-            param['optimizer'] = optimizer_to_dict(self.optimizer)
+            param['optimizer'] = get_torch_bytes(self.optimizer.state_dict())
 
         meta = {'batch_size': self.batch_size, 'max_iter': self.max_iter, 'threshold': self.threshold,
                 'optimizer_param': self.optimizer_param, 'learning_rate_param': self.learning_rate_param, 'init_param': self.init_param}
-        export_ = ModelExporter(data=param, meta=meta)
+        export_ = ModelIO(data=param, meta=meta)
 
         return export_
 
-    @classmethod
-    def from_model(cls, model: dict) -> Module:
-        if not hasattr(model, 'model'):
-            raise ('key "param" is not found in the input model dict')
-        param = model['param']
-        if not hasattr(param, 'model'):
-            raise ValueError("param dict must have key 'model' that contains the model parameter and structure info")
+    def from_model(self, model: ModelIO):
+
+        model = model.dict()
+        if 'data' not in model:
+            raise ValueError('key "data" is not found in the input model dict')
+        model_param = model['data']
+        if 'model' not in model_param:
+            raise ValueError("param dict must have key 'model' that contains the model parameter and structure info")
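+        # the payload round-trips through HomoLRModel.to_dict()/from_dict():
+        # layer weights are stored as plain lists so the dict stays
+        # JSON-serializable, while the optimizer state travels as
+        # torch-serialized bytes (see get_torch_bytes/recover_torch_bytes)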
ValueError("param dict must have key 'model' that contains the model parameter and structure info") + self.model = HomoLRModel.from_dict(model_param['model']) + self.model.l1 = self.l1 + if hasattr(model_param, 'optimizer'): + self.optimizer_state_dict = recover_torch_bytes(model_param['optimizer']) + self.loaded_meta = model['meta'] diff --git a/python/fate/ml/glm/homo_lr/test/local_test.py b/python/fate/ml/glm/homo_lr/test/local_test.py index 8bbf5d4fb1..15810c8629 100644 --- a/python/fate/ml/glm/homo_lr/test/local_test.py +++ b/python/fate/ml/glm/homo_lr/test/local_test.py @@ -5,7 +5,18 @@ import pandas as pd from fate.arch.dataframe import PandasReader from fate.ml.glm.homo_lr.client import HomoLRClient, HomoLRModel +import logging +import logging + +# Get the root logger +logger = logging.getLogger() +logger.setLevel(logging.INFO) +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +ch.setFormatter(formatter) +logger.addHandler(ch) computing = CSession() @@ -15,13 +26,24 @@ federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]), ) -df = pd.read_csv('./examples/data/vehicle_scale_homo_guest.csv') +df = pd.read_csv('./../../../../../../examples/data/breast_homo_guest.csv') df['sample_id'] = [i for i in range(len(df))] reader = PandasReader(sample_id_name='sample_id', match_id_name="id", label_name="y", dtype="object") data = reader.to_frame(ctx, df) df = data.as_pd_df() -client = HomoLRClient(50, 32, learning_rate_param=0.01) -a = client.fit(ctx, data) +client = HomoLRClient(50, 800, learning_rate_param=0.01) +client.l2 = 0.01 +client.l1 = 0.01 +client.fit(ctx, data) +export_model = client.get_model() +pred = client.predict(ctx, data) + +# print('load model and warm-starting') +# client_2 = HomoLRClient(1, batch_size=800, learning_rate_param=0.001) +# client_2.from_model(export_model) +# client_2.fit(ctx, data) +# from fate.components.core.params._learning_rate import LRSchedulerParam +# from fate.components.core.params._optimizer import OptimizerParam diff --git a/python/fate/ml/nn/algo/homo/fedavg.py b/python/fate/ml/nn/algo/homo/fedavg.py index 12e18642e0..15ddfeddc4 100644 --- a/python/fate/ml/nn/algo/homo/fedavg.py +++ b/python/fate/ml/nn/algo/homo/fedavg.py @@ -39,9 +39,12 @@ class FedAVGCLient(FedTrainerClient): def __init__(self, ctx: Context, - model: Module, loss_fn: Module, optimizer: Optimizer, + model: Module, training_args: TrainingArguments, fed_args: FedArguments, - train_set: Dataset, val_set: Dataset = None, + train_set: Dataset, + val_set: Dataset = None, + loss_fn: Module = None, + optimizer: Optimizer = None, scheduler: _LRScheduler = None, callbacks: List[TrainerCallback] = [], data_collator: Callable=None, @@ -51,8 +54,8 @@ def __init__(self, local_mode: bool = False ): - super().__init__(ctx, model, loss_fn, optimizer, training_args, fed_args, train_set, val_set, data_collator, tokenizer, - scheduler, callbacks, use_hf_default_behavior, + super().__init__(ctx, model, training_args, fed_args, train_set, val_set, loss_fn, optimizer, data_collator, scheduler, + tokenizer, callbacks, use_hf_default_behavior, compute_metrics=compute_metrics, local_mode=local_mode) def init_aggregator(self): diff --git a/python/fate/ml/nn/trainer/trainer_base.py b/python/fate/ml/nn/trainer/trainer_base.py index 232ab35ecf..3a8a04a545 100644 --- a/python/fate/ml/nn/trainer/trainer_base.py +++ b/python/fate/ml/nn/trainer/trainer_base.py @@ -25,7 +25,7 @@ from typing 
import Optional import time from dataclasses import dataclass, field, fields -from transformers import trainer, trainer_callback +from transformers.trainer_callback import PrinterCallback # Reset the logger to redirect logs output @@ -108,6 +108,7 @@ def to_dict(self): @dataclass class TrainingArguments(_hf_TrainingArguments): + # in fate-2.0, we will control the output dir when using pipeline output_dir: str = field(default='./') disable_tqdm: bool = field(default=True) save_strategy: str = field(default="no") @@ -115,6 +116,8 @@ class TrainingArguments(_hf_TrainingArguments): evaluation_strategy: str = field(default="no") logging_dir: str = field(default=None) checkpoint_idx: int = field(default=None) + # by default we use constant learning rate, the same as FATE-1.X + lr_scheduler_type: str = field(default="constant") def __post_init__(self): @@ -492,6 +495,14 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: self._client_send_parameters(state, args, train_dataloader) +class FatePrinterCallback(TrainerCallback): + + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_local_process_zero: + _ = logs.pop("total_flos", None) + logger.info(str(logs)) + + class CallbackWrapper(TrainerCallback): @@ -651,12 +662,12 @@ class StdFedTrainerMixin(ShortcutCallBackInterFace, FedCallbackInterface): def __init__(self, ctx: Context, model: nn.Module, - loss_fn: nn.Module, - optimizer: torch.optim.Optimizer, training_args: TrainingArguments, fed_args: FedArguments, train_set: Dataset, val_set: Dataset = None, + loss_fn: nn.Module = None, + optimizer: torch.optim.Optimizer = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, tokenizer: Optional[PreTrainedTokenizer] = None, callbacks: Optional[List[TrainerCallback]] = [], @@ -712,6 +723,14 @@ def _add_fate_callback(self, callback_handler): # fed callback aggregator init(once), parameter check(once), # on federation of fedcallback # callbacks of shortcutcallback + new_callback_list = [] + for i in callback_handler.callbacks: + if isinstance(i, PrinterCallback): + continue + else: + new_callback_list.append(i) + new_callback_list.append(FatePrinterCallback()) + callback_handler.callbacks = new_callback_list callback_handler.callbacks.append(FedCallbackWrapper(self.ctx, self)) if self.parameter_alignment: callback_handler.callbacks.append(FedParameterAlignCallback(self, @@ -762,12 +781,12 @@ class FedTrainerClient(Trainer, StdFedTrainerMixin): def __init__(self, ctx: Context, model: nn.Module, - loss_fn: nn.Module, - optimizer: torch.optim.Optimizer, training_args: TrainingArguments, fed_args: FedArguments, train_set: Dataset, val_set: Dataset = None, + loss_fn: nn.Module = None, + optimizer: torch.optim.Optimizer = None, data_collator: Callable = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, tokenizer: Optional[PreTrainedTokenizer] = None, @@ -777,13 +796,6 @@ def __init__(self, local_mode: bool = False, parameter_alignment = True ): - - # default use no lr decay - if scheduler is None: - if use_hf_default_behavior and optimizer is None: - pass - else: - scheduler = LambdaLR(optimizer, lambda x: 1) # in case you forget to set evaluation_strategy if val_set is not None and training_args.evaluation_strategy == 'no': @@ -859,11 +871,16 @@ def prediction_step(self, if self._use_hf_default_behavior: return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) else: - with torch.no_grad(): - feats, labels = inputs - logits = model(feats) - 
return (None, logits, labels) - + # (features, labels), this format is used in FATE-1.x + # now the model is in eval status + if isinstance(inputs, tuple) or isinstance(inputs, list) and len(inputs) == 2: + with torch.no_grad(): + feats, labels = inputs + logits = model(feats) + return (None, logits, labels) + else: + return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) + class FedTrainerServer(object): diff --git a/python/fate/ml/utils/model_io.py b/python/fate/ml/utils/model_io.py index 898a07bffe..406037009b 100644 --- a/python/fate/ml/utils/model_io.py +++ b/python/fate/ml/utils/model_io.py @@ -1,7 +1,7 @@ from typing import Optional -class ModelExporter: +class ModelIO: _META = "meta" _DATA = "data" From b90593db391e4f3d5636189b36cb64375d71f5d9 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 11 Jul 2023 19:32:43 +0800 Subject: [PATCH 12/61] Add Homo-LR Pipeline Component: Enable Train Signed-off-by: cwj --- homo_nn.yaml | 135 ------------------ python/fate/components/components/__init__.py | 6 + python/fate/components/components/homo_lr.py | 103 +++++++++++++ .../components/components/nn/nn_runner.py | 15 ++ .../components/nn/runner/default_runner.py | 45 ++++-- python/fate/ml/glm/homo_lr/client.py | 14 +- python/fate/ml/glm/homo_lr/server.py | 5 +- python/fate/ml/glm/homo_lr/test/local_test.py | 2 +- python/fate/ml/nn/trainer/trainer_base.py | 8 ++ 9 files changed, 177 insertions(+), 156 deletions(-) delete mode 100644 homo_nn.yaml create mode 100644 python/fate/components/components/homo_lr.py diff --git a/homo_nn.yaml b/homo_nn.yaml deleted file mode 100644 index 7ec0d60dd6..0000000000 --- a/homo_nn.yaml +++ /dev/null @@ -1,135 +0,0 @@ -component: - name: homo_nn - description: '' - provider: fate - version: 2.0.0-alpha - labels: [] - roles: - - guest - - host - - arbiter - parameters: - runner_module: - type: str - default: default_runner - optional: true - description: name of your runner script - type_meta: - title: str - type: string - default: - description: path to your runner script folder - runner_class: - type: str - default: DefaultRunner - optional: true - description: class name of your runner class - type_meta: - title: str - type: string - default: - description: path to your runner script folder - runner_conf: - type: dict - default: {} - optional: true - description: the parameter dict of the NN runner class - type_meta: - title: dict - type: object - default: {} - description: the parameter dict of the NN runner class - source: - type: str - default: - optional: true - description: path to your runner script folder - type_meta: - title: str - type: string - default: - description: path to your runner script folder - input_artifacts: - data: - train_data: - types: - - dataframe - optional: true - stages: - - train - roles: - - guest - - host - description: '' - is_multi: false - validate_data: - types: - - dataframe - optional: true - stages: - - train - roles: - - guest - - host - description: '' - is_multi: false - test_data: - types: - - dataframe - optional: true - stages: - - predict - roles: - - guest - - host - description: '' - is_multi: false - model: - model_input: - types: - - model_directory - optional: false - stages: - - predict - roles: - - guest - - host - description: '' - is_multi: false - output_artifacts: - data: - data_output: - types: - - dataframe - optional: false - stages: - - predict - - train - roles: - - guest - - host - description: '' - is_multi: false - model: - model_output: - types: - - model_directory - 
optional: false - stages: - - train - roles: - - guest - - host - description: '' - is_multi: false - metric: - metric: - types: - - json_metric - optional: false - stages: [] - roles: [] - description: metric, invisible for user - is_multi: false -schema_version: v1 - diff --git a/python/fate/components/components/__init__.py b/python/fate/components/components/__init__.py index ae5ea8b14d..86cafd12e2 100644 --- a/python/fate/components/components/__init__.py +++ b/python/fate/components/components/__init__.py @@ -62,6 +62,12 @@ def homo_nn(self): return homo_nn + @_lazy_cpn + def homo_lr(self): + from .homo_lr import homo_lr + + return homo_lr + @_lazy_cpn def dataframe_transformer(self): from .dataframe_transformer import dataframe_transformer diff --git a/python/fate/components/components/homo_lr.py b/python/fate/components/components/homo_lr.py new file mode 100644 index 0000000000..7dc567cc05 --- /dev/null +++ b/python/fate/components/components/homo_lr.py @@ -0,0 +1,103 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os + +import pandas as pd +from fate.arch import Context +from fate.arch.dataframe import PandasReader +from fate.ml.glm.homo_lr.client import HomoLRClient +from fate.ml.glm.homo_lr.server import HomoLRServer +from fate.components.components.utils.predict_format import LABEL +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params +from fate.components.components.utils import consts + +logger = logging.getLogger(__name__) + + + +@cpn.component(roles=[GUEST, HOST, ARBITER]) +def homo_lr(ctx, role): + ... 
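+
+# An illustrative client-side sketch of how this component is expected to be
+# driven from fate_client (the pipeline-side component name `HomoLR` and the
+# table names below are assumptions, mirroring the FateFlowPipeline examples
+# used elsewhere in this patch series):
+#
+#   from fate_client.pipeline import FateFlowPipeline
+#   from fate_client.pipeline.components.fate import HomoLR
+#   from fate_client.pipeline.interface import DataWarehouseChannel
+#
+#   pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998")
+#   homo_lr_0 = HomoLR("homo_lr_0", epochs=10, batch_size=100)
+#   homo_lr_0.guest.component_setting(train_data=DataWarehouseChannel(
+#       name="breast_homo_guest", namespace="experiment"))
+#   homo_lr_0.hosts[0].component_setting(train_data=DataWarehouseChannel(
+#       name="breast_homo_host", namespace="experiment"))
+#   pipeline.add_task(homo_lr_0)
+#   pipeline.compile()
+#   pipeline.fit()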
+ + +@homo_lr.train() +def train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True), + learning_rate_scheduler: cpn.parameter(type=params.lr_scheduler_param(), + default=params.LRSchedulerParam(method="linear", + scheduler_params={"start_factor": 1.0}), + desc="learning rate scheduler, " + "select method from {'step', 'linear', 'constant'}" + "for list of configurable arguments, " + "refer to torch.optim.lr_scheduler"), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, + desc="max iteration num"), + batch_size: cpn.parameter(type=params.conint(ge=-1), default=100, + desc="batch size, " + "value less or equals to 0 means full batch"), + optimizer: cpn.parameter(type=params.optimizer_param(), + default=params.OptimizerParam(method="sgd", penalty='l2', alpha=1.0, + optimizer_params={"lr": 1e-2, "weight_decay": 0})), + init_param: cpn.parameter(type=params.init_param(), + default=params.InitParam(method='zeros', fit_intercept=True), + desc="Model param init setting."), + threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5, + desc="predict threshold for binary data"), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + train_input_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True), + output_model: cpn.json_model_output(roles=[GUEST, HOST]) +): + + sub_ctx = ctx.sub_ctx(consts.TRAIN) + + if role.is_guest or role.is_host: # is client + + logger.info('optim param {} init param {}'.format(optimizer.dict(), init_param.dict())) + + client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer, init_param=init_param, + learning_rate_scheduler=0.01, threshold=threshold) + train_df = train_data.read() + validate_df = validate_data.read() if validate_data else None + client.fit(sub_ctx, train_df, validate_df) + output_model.write({"aaa": 1}, metadata={"bbb": 2}) + elif role.is_arbiter: # is server + logger.info('hello') + server = HomoLRServer() + server.fit(sub_ctx) + + +@homo_lr.predict() +def predict( + ctx, + role: Role, + test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + batch_size: cpn.parameter(type=params.conint(ge=-1), default=100, + desc="batch size, " + "value less or equals to 0 means full batch"), + threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5, + desc="predict threshold for binary data"), + predict_input_model: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]) +): + + if role.is_guest or role.is_host: # is client + pass + + elif role.is_arbiter: # is server + logger.info("arbiter skip predict") diff --git a/python/fate/components/components/nn/nn_runner.py b/python/fate/components/components/nn/nn_runner.py index 2172cfac8a..c8d20dc54d 100644 --- a/python/fate/components/components/nn/nn_runner.py +++ b/python/fate/components/components/nn/nn_runner.py @@ -17,6 +17,10 @@ logger = logging.getLogger(__name__) +FATE_DF = 'fate_df' +STR_PATH = 'str_path' + + def _convert_to_numpy_array(data: Union[pd.Series, pd.DataFrame, np.ndarray, torch.Tensor]) -> np.ndarray: if isinstance(data, pd.Series) or isinstance(data, pd.DataFrame): return data.to_numpy() @@ -70,11 +74,17 @@ def __init__(self, train_data: Union[pd.DataFrame, str, DataFrame] = None, self.train_ids = None self.validate_ids = None self.test_ids = None + self.input_type = None + + # training + if isinstance(train_data, DataFrame): self.train_data, 
self.train_ids, self.schema = self._extract_fate_df(train_data) + self.input_type = FATE_DF else: self.train_data = train_data self.train_ids = SampleIDs() + self.input_type = STR_PATH if isinstance(validate_data, DataFrame): self.validate_data, self.validate_ids, _ = self._extract_fate_df(validate_data) @@ -82,11 +92,16 @@ def __init__(self, train_data: Union[pd.DataFrame, str, DataFrame] = None, self.validate_data = validate_data self.validate_ids = SampleIDs() + # prediction + if isinstance(test_data, DataFrame): self.test_data, self.test_ids, self.schema = self._extract_fate_df(test_data) + self.input_type = FATE_DF + else: self.test_data = test_data self.test_ids = SampleIDs() + self.input_type = STR_PATH self.saved_model_path = saved_model_path self.fate_save_path = fate_save_path diff --git a/python/fate/components/components/nn/runner/default_runner.py b/python/fate/components/components/nn/runner/default_runner.py index fcabc4cf22..37bb27b69d 100644 --- a/python/fate/components/components/nn/runner/default_runner.py +++ b/python/fate/components/components/nn/runner/default_runner.py @@ -11,8 +11,10 @@ from fate.ml.nn.trainer.trainer_base import FedArguments, TrainingArguments, FedTrainerClient, FedTrainerServer from typing import Union, Type, Callable, Optional from transformers.trainer_utils import get_last_checkpoint +from fate.ml.nn.dataset.table import TableDataset from typing import Literal import logging +from fate.components.components.utils import consts logger = logging.getLogger(__name__) @@ -157,14 +159,36 @@ def _loader_load_from_conf(self, conf, return_class=False): return Loader.from_dict(conf).load_item() return Loader.from_dict(conf).call_item() - def _prepare_dataset(self, dataset_conf, cpn_input_data): - dataset = self._loader_load_from_conf(dataset_conf) - if hasattr(dataset, 'load'): - if cpn_input_data is not None: - dataset.load(cpn_input_data) - return dataset + def _prepare_dataset(self, dataset_conf, cpn_input_data, schema=None): + + if cpn_input_data is None: + logger.info('input cpn data is None, return') + return + + if dataset_conf is None: + # Automatically create dataset class + label_name = None + if schema is not None: + label_name = schema.label_name + if label_name is None: + logger.info('schema is provided, but label name is None, TableDataset will automatically infer label') + else: + logger.info('schema is provided, label name is {}'.format(label_name)) + else: + logger.info('schema is not provided') + + if self.task_type == consts.MULTI: + dataset = TableDataset(label_col=label_name, flatten_label=True, label_dtype='long') else: - return None + dataset = TableDataset(label_col=label_name) + logger.info('dataset conf is not set, use default FATE Table Dataset') + + else: + dataset = self._loader_load_from_conf(dataset_conf) + + if hasattr(dataset, 'load'): + dataset.load(cpn_input_data) + return dataset else: raise ValueError(f"dataset {dataset} has no load() method") @@ -186,9 +210,10 @@ def setup(self, cpn_input_data: NNInput, stage='train'): # load arguments, models, etc # prepare datatset # dataet - train_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_train_data()) - validate_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_validate_data()) - test_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_test_data()) + logger.info('NNInput data type is {}'.format(cpn_input_data.input_type)) + train_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_train_data(), 
schema=cpn_input_data.get_schema()) + validate_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_validate_data(), schema=cpn_input_data.get_schema()) + test_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_test_data(), schema=cpn_input_data.get_schema()) # load model model = self._loader_load_from_conf(self.model_conf) if model is None: diff --git a/python/fate/ml/glm/homo_lr/client.py b/python/fate/ml/glm/homo_lr/client.py index 9b3d3d33cf..81632087bd 100644 --- a/python/fate/ml/glm/homo_lr/client.py +++ b/python/fate/ml/glm/homo_lr/client.py @@ -34,6 +34,7 @@ def from_fate_dataframe(df: DataFrame): sample_id = schema.sample_id_name match_id = schema.match_id_name label = schema.label_name + logger.info('columns are {} {} {}'.format(sample_id, match_id, label)) pd_df = df.as_pd_df() features = pd_df.drop([sample_id, match_id, label], axis=1) sample_ids = pd_df[[sample_id]] @@ -190,8 +191,8 @@ def __len__(self): class HomoLRClient(HomoModule): - def __init__(self, max_iter: int=5, batch_size: int=32, optimizer_param=None, - learning_rate_param=None, + def __init__(self, epochs: int=5, batch_size: int=32, optimizer_param=None, + learning_rate_scheduler=None, init_param=None, threshold: float=0.5 ) -> None: @@ -203,10 +204,10 @@ def __init__(self, max_iter: int=5, batch_size: int=32, optimizer_param=None, self.predict_data = None # set vars - self.max_iter = max_iter + self.max_iter = epochs self.batch_size = batch_size self.optimizer_param = optimizer_param - self.learning_rate_param = learning_rate_param + self.learning_rate_param = learning_rate_scheduler self.init_param = init_param self.threshold = threshold self.run_ovr = False @@ -288,7 +289,6 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No per_device_train_batch_size=self.batch_size, per_gpu_eval_batch_size=self.batch_size) self.trainer = FedAVGCLient(ctx, model=self.model, loss_fn=loss_fn, optimizer=self.optimizer, train_set=train_set, val_set=validate_set, training_args=train_arg, fed_args=fed_arg, data_collator=default_data_collator) - self.trainer.set_local_mode() self.trainer.train() def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: @@ -305,8 +305,8 @@ def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: else: trainer = self.trainer predict_rs = trainer.predict(predict_set) - - return predict_rs + rs = {"predict_score": predict_rs.predictions, 'label': predict_rs.label_ids} + return rs def get_model(self) -> ModelIO: param = {} diff --git a/python/fate/ml/glm/homo_lr/server.py b/python/fate/ml/glm/homo_lr/server.py index 5f70397491..89b4ee38ad 100644 --- a/python/fate/ml/glm/homo_lr/server.py +++ b/python/fate/ml/glm/homo_lr/server.py @@ -13,14 +13,13 @@ class HomoLRServer(HomoModule): def __init__(self) -> None: pass - def fit(self, ctx: Context, data: DataFrame) -> None: + def fit(self, ctx: Context, data: DataFrame=None) -> None: - server = FedAVGServer(ctx=ctx) logger.info('server class init done, start fed training') server.train() logger.info('homo lr fit done') - def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: + def predict(self, ctx: Context, predict_data: DataFrame=None) -> DataFrame: logger.info('kkip prediction stage') \ No newline at end of file diff --git a/python/fate/ml/glm/homo_lr/test/local_test.py b/python/fate/ml/glm/homo_lr/test/local_test.py index 15810c8629..89da930ff1 100644 --- a/python/fate/ml/glm/homo_lr/test/local_test.py +++ b/python/fate/ml/glm/homo_lr/test/local_test.py 
@@ -33,7 +33,7 @@ data = reader.to_frame(ctx, df) df = data.as_pd_df() -client = HomoLRClient(50, 800, learning_rate_param=0.01) +client = HomoLRClient(50, 800, learning_rate_scheduler=0.01) client.l2 = 0.01 client.l1 = 0.01 client.fit(ctx, data) diff --git a/python/fate/ml/nn/trainer/trainer_base.py b/python/fate/ml/nn/trainer/trainer_base.py index 3a8a04a545..c7887ddb8b 100644 --- a/python/fate/ml/nn/trainer/trainer_base.py +++ b/python/fate/ml/nn/trainer/trainer_base.py @@ -905,6 +905,14 @@ def set_fed_context(self, ctx: Context): assert isinstance(ctx, Context), 'ctx must be a Context object, but got {}'.format(ctx) self.ctx = ctx + def set_local_mode(self): + self.local_mode = True + logger.info('trainer set to local mode') + + def set_fed_mode(self): + self.local_mode = False + logger.info('trainer set to federated mode') + def init_aggregator(self): return None From 5a2c72add61d06725b1630eb444aa96a9271a442 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 11 Jul 2023 20:15:42 +0800 Subject: [PATCH 13/61] edit selection (#4661) Signed-off-by: Yu Wu --- .../components/hetero_feature_selection.py | 5 +++++ .../hetero_feature_selection.py | 17 +++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/python/fate/components/components/hetero_feature_selection.py b/python/fate/components/components/hetero_feature_selection.py index 6e94de4d3f..ef6a71ccd7 100644 --- a/python/fate/components/components/hetero_feature_selection.py +++ b/python/fate/components/components/hetero_feature_selection.py @@ -75,6 +75,11 @@ def train( if manual_param.keep_col is not None: keep_col = [columns[anonymous_columns.index(col)] for col in manual_param.keep_col] manual_param.keep_col = keep_col + # temp code start + iv_param = iv_param.dict() + statistic_param = statistic_param.dict() + manual_param = manual_param.dict() + # temp code end input_models = [model.read() for model in input_models] if role.is_guest: selection = HeteroSelectionModuleGuest(method, select_col, input_models, diff --git a/python/fate/ml/feature_selection/hetero_feature_selection.py b/python/fate/ml/feature_selection/hetero_feature_selection.py index 8c9d64e015..2a8558c89a 100644 --- a/python/fate/ml/feature_selection/hetero_feature_selection.py +++ b/python/fate/ml/feature_selection/hetero_feature_selection.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -DEFAULT_METRIC = {"iv": ["iv"], "statistic": ["mean"]} +DEFAULT_METRIC = {"iv": ["iv"], "statistics": ["mean"]} class HeteroSelectionModuleGuest(HeteroModule): @@ -186,7 +186,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: param=self.iv_param, model=model, keep_one=self.keep_one) - elif filter_type == "statistic": + elif filter_type == "statistics": model = self.isometric_model_dict.get("statistics", None) if model is None: raise ValueError(f"Cannot find statistics model in input, please check") @@ -400,13 +400,13 @@ def fit(self, ctx: Context, train_data, validate_data=None): """metric_names = self.param.metrics or []""" # temp code ends # local only - if self.method in ["statistic"]: + if self.method in ["statistics"]: for metric_name in metric_names: - if metric_name not in self.model.get("metrics", {}): + if metric_name not in self.model.get("meta", {}).get("metrics", {}): raise ValueError(f"metric {metric_name} not found in given statistic model with metrics: " - f"{metric_names}, please check") - - metrics_all = pd.DataFrame(self.model.get("metrics_summary", {})).loc[metric_names] + f"{self.model.get('metrics', {})}, 
please check") + model_data = self.model.get("data", {}) + metrics_all = pd.DataFrame(model_data.get("metrics_summary", {})).loc[metric_names] self._all_metrics = metrics_all missing_col = set(self._prev_selected_mask[self._prev_selected_mask].index). \ difference(set(metrics_all.columns)) @@ -431,7 +431,8 @@ def fit(self, ctx: Context, train_data, validate_data=None): # host does not perform local iv selection if ctx.local[0] == "host": return - iv_metrics = pd.Series(self.model["metrics_summary"]["iv"]) + model_data = self.model.get("data", {}) + iv_metrics = pd.Series(model_data["metrics_summary"]["iv"]) metrics_all = pd.DataFrame(iv_metrics).T.rename({0: "iv"}, axis=0) self._all_metrics = metrics_all # works for multiple iv filters From 31021ea7aacb407c6557eb506d1f3c1a74ccd4c5 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Wed, 12 Jul 2023 11:12:42 +0800 Subject: [PATCH 14/61] edit selection example Signed-off-by: Yu Wu --- examples/pipeline/test_selection.py | 44 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/examples/pipeline/test_selection.py b/examples/pipeline/test_selection.py index 1cfa078484..8bc347d296 100644 --- a/examples/pipeline/test_selection.py +++ b/examples/pipeline/test_selection.py @@ -16,46 +16,42 @@ from fate_client.pipeline import FateFlowPipeline from fate_client.pipeline.components.fate import FeatureScale -from fate_client.pipeline.components.fate import Intersection from fate_client.pipeline.components.fate import Statistics, HeteroFeatureSelection from fate_client.pipeline.interface import DataWarehouseChannel pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") -intersection_0 = Intersection("intersection_0", +"""intersection_0 = Intersection("intersection_0", method="raw") intersection_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace="experiment")) + namespace="experiment_64")) intersection_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace="experiment")) + namespace="experiment_64")) intersection_1 = Intersection("intersection_1", method="raw") intersection_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace="experiment")) + namespace="experiment_64")) intersection_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace="experiment")) - + namespace="experiment_64")) +""" feature_scale_0 = FeatureScale("feature_scale_0", - method="standard", - train_data=intersection_0.outputs["output_data"]) - -feature_scale_1 = FeatureScale("feature_scale_1", - test_data=intersection_1.outputs["output_data"], - input_model=feature_scale_0.outputs["output_model"]) + method="standard") +feature_scale_0.guest.component_setting(train_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment")) +feature_scale_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment")) -statistics_0 = Statistics("statistics_0", input_data=feature_scale_1.outputs["test_output_data"], +statistics_0 = Statistics("statistics_0", input_data=feature_scale_0.outputs["train_output_data"], metrics=["mean", "max", "std", "var", "kurtosis", "skewness"]) -selection_0 = HeteroFeatureSelection(name="selection_0", train_data=intersection_0.outputs["output_data"], - method=["statistic"], - input_statistic_model=statistics_0.outputs["output_model"], +selection_0 = 
HeteroFeatureSelection("selection_0", + train_data=feature_scale_0.outputs["train_output_data"], + method=["statistics"], + input_models=[statistics_0.outputs["output_model"]], statistic_param={"metrics": ["mean", "max", "kurtosis", "skewness"]}) -pipeline.add_task(intersection_0) -pipeline.add_task(intersection_1) pipeline.add_task(feature_scale_0) -pipeline.add_task(feature_scale_1) pipeline.add_task(statistics_0) pipeline.add_task(selection_0) pipeline.compile() @@ -65,9 +61,15 @@ print(json.dumps(pipeline.get_task_info("selection_0").get_output_model(), indent=4)) predict_pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") -pipeline.deploy([intersection_0, selection_0]) +pipeline.deploy([feature_scale_0, selection_0]) + deployed_pipeline = pipeline.get_deployed_pipeline() +deployed_pipeline.feature_scale_0.guest.component_setting(test_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment")) +deployed_pipeline.feature_scale_0.hosts[0].component_setting(test_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment")) + predict_pipeline.add_task(deployed_pipeline) print("\n\n\n") From eeba3845dd6f47e007c2cd670527d9d187fe04ff Mon Sep 17 00:00:00 2001 From: weiwee Date: Tue, 11 Jul 2023 20:11:12 +0800 Subject: [PATCH 15/61] remove exception from json restful logging Signed-off-by: weiwee --- .../components/core/component_desc/artifacts/metric/_json.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/fate/components/core/component_desc/artifacts/metric/_json.py b/python/fate/components/core/component_desc/artifacts/metric/_json.py index 1d70a7e91b..7438987e87 100644 --- a/python/fate/components/core/component_desc/artifacts/metric/_json.py +++ b/python/fate/components/core/component_desc/artifacts/metric/_json.py @@ -39,7 +39,6 @@ def write(self, data): output = requests.post(url=self.artifact.uri.original_uri, json=dict(data=[data])) except Exception as e: logger.error(f"write data `{data}` to {self.artifact.uri.original_uri} failed, error: {e}") - raise e else: logger.debug(f"write data `{data}` to {self.artifact.uri.original_uri} success, output: {output}") From 900c440c14af1dbd6d8c3e2dae057193508720ef Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 12 Jul 2023 12:36:38 +0800 Subject: [PATCH 16/61] add quantile typing Signed-off-by: weiwee --- python/fate/arch/tensor/inside/_op_quantile.py | 2 +- rust/fate_utils/python/fate_utils/quantile.pyi | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 rust/fate_utils/python/fate_utils/quantile.pyi diff --git a/python/fate/arch/tensor/inside/_op_quantile.py b/python/fate/arch/tensor/inside/_op_quantile.py index bcbf3f1d97..d790112f2d 100644 --- a/python/fate/arch/tensor/inside/_op_quantile.py +++ b/python/fate/arch/tensor/inside/_op_quantile.py @@ -24,6 +24,6 @@ def __add__(self, other: "GKSummary"): def __iadd__(self, other: torch.Tensor): if isinstance(other, torch.Tensor): - self._summary.insert_array(ohter.numpy()) + self._summary.insert_array(other.numpy()) return self return NotImplemented diff --git a/rust/fate_utils/python/fate_utils/quantile.pyi b/rust/fate_utils/python/fate_utils/quantile.pyi new file mode 100644 index 0000000000..e629e92f55 --- /dev/null +++ b/rust/fate_utils/python/fate_utils/quantile.pyi @@ -0,0 +1,14 @@ +from typing import List, Optional +import numpy as np + +class QuantileSummaryStream: + def __init__(self, epsilon: Optional[float] = None): ... + def __getstate__(self) -> bytes: ... 
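+    # __getstate__/__setstate__ make the Rust-backed summary picklable, so a
+    # partially built summary can be serialized and shipped between processes.
+    # Illustrative round trip (float64 input assumed, matching what the Python
+    # wrapper passes to insert_array):
+    #   s = QuantileSummaryStream(0.001)
+    #   s.insert_array(np.linspace(0.0, 1.0, num=100))
+    #   s = pickle.loads(pickle.dumps(s))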
+ def __setstate__(self, state: bytes) -> None: ... + def insert_array(self, data: np.ndarray) -> None: ... + def quantile(self, phi: List[float]) -> List[float]: ... + def merge(self, other: "QuantileSummaryStream") -> "QuantileSummaryStream": ... + +def summary_f64_ix2(data: np.ndarray, epsilon: float) -> List[QuantileSummaryStream]: ... +def quantile_f64_ix1(data: np.ndarray, q: List[float], epsilon: float) -> List[float]: ... +def quantile_f64_ix2(data: np.ndarray, q: List[float], epsilon: float) -> np.ndarray: ... From 0132d44b418ead4c35127574528c7162e282fee6 Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 12 Jul 2023 12:58:26 +0800 Subject: [PATCH 17/61] update quantile Signed-off-by: weiwee --- .../fate/arch/tensor/inside/_op_quantile.py | 46 +++++++++++++++---- .../crates/fate_utils/src/quantile.rs | 2 +- .../fate_utils/python/fate_utils/quantile.pyi | 2 +- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/python/fate/arch/tensor/inside/_op_quantile.py b/python/fate/arch/tensor/inside/_op_quantile.py index d790112f2d..665f22dd9f 100644 --- a/python/fate/arch/tensor/inside/_op_quantile.py +++ b/python/fate/arch/tensor/inside/_op_quantile.py @@ -1,3 +1,6 @@ +from typing import List, Union + +import numpy import torch from fate_utils import quantile @@ -12,18 +15,43 @@ def quantile_fi(input: torch.Tensor, q, epsilon): class GKSummary: - def __init__(self, summary=None) -> None: - if summary is None: - summary = quantile.QuantileSummaryStream() - self._summary = summary + """ + GKSummary is a summary of a stream of numbers, which can be used to estimate quantiles. + + Examples: + >>> summary = GKSummary(0.01) + >>> summary += torch.tensor([1.0, 2.0, 3.0]) + >>> summary += torch.tensor([4.0, 5.0, 6.0]) + >>> summary += torch.tensor([7.0, 8.0, 9.0]) + >>> summary.queries([0.1, 0.2]) + [2.0, 3.0] + """ + + def __init__(self, epsilon: float) -> None: + self._summary = quantile.QuantileSummaryStream(epsilon) + + def merge(self, other: "GKSummary"): + """merge other summary into self.""" + self._summary.merge(other._summary) + return self + + def push(self, array: Union[torch.Tensor, numpy.ndarray]): + """push elements in array into summary.""" + if isinstance(array, torch.Tensor): + array = array.numpy() + self._summary.insert_array(array) + return self def __add__(self, other: "GKSummary"): if isinstance(other, GKSummary): - return GKSummary(self._summary.merge(other._summary)) + return self.merge(other) return NotImplemented - def __iadd__(self, other: torch.Tensor): - if isinstance(other, torch.Tensor): - self._summary.insert_array(other.numpy()) - return self + def __iadd__(self, other: Union[torch.Tensor, numpy.ndarray]): + if isinstance(other, torch.Tensor) or isinstance(other, numpy.ndarray): + return self.push(other) return NotImplemented + + def queries(self, q: List[float]): + """return quantile values of q.""" + return torch.tensor(self._summary.queries(q)) diff --git a/rust/fate_utils/crates/fate_utils/src/quantile.rs b/rust/fate_utils/crates/fate_utils/src/quantile.rs index 482ca06f60..14f8b802c1 100644 --- a/rust/fate_utils/crates/fate_utils/src/quantile.rs +++ b/rust/fate_utils/crates/fate_utils/src/quantile.rs @@ -71,7 +71,7 @@ impl QuantileSummaryStream { } Ok(()) } - pub fn quantile(&self, phi: Vec) -> Vec { + pub fn queries(&self, phi: Vec) -> Vec { phi.iter() .map(|p| self.0.as_ref().unwrap().quantile(*p).0) .collect() diff --git a/rust/fate_utils/python/fate_utils/quantile.pyi b/rust/fate_utils/python/fate_utils/quantile.pyi index e629e92f55..7b18ce3116 
100644 --- a/rust/fate_utils/python/fate_utils/quantile.pyi +++ b/rust/fate_utils/python/fate_utils/quantile.pyi @@ -6,7 +6,7 @@ class QuantileSummaryStream: def __getstate__(self) -> bytes: ... def __setstate__(self, state: bytes) -> None: ... def insert_array(self, data: np.ndarray) -> None: ... - def quantile(self, phi: List[float]) -> List[float]: ... + def queries(self, phi: List[float]) -> List[float]: ... def merge(self, other: "QuantileSummaryStream") -> "QuantileSummaryStream": ... def summary_f64_ix2(data: np.ndarray, epsilon: float) -> List[QuantileSummaryStream]: ... From 3eb7fbdd60482ebb2e5b27102bec6665d40f39e4 Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 12 Jul 2023 13:23:30 +0800 Subject: [PATCH 18/61] update quantile Signed-off-by: weiwee --- .../fate/arch/tensor/inside/_op_quantile.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/python/fate/arch/tensor/inside/_op_quantile.py b/python/fate/arch/tensor/inside/_op_quantile.py index 665f22dd9f..9c975c04c0 100644 --- a/python/fate/arch/tensor/inside/_op_quantile.py +++ b/python/fate/arch/tensor/inside/_op_quantile.py @@ -19,27 +19,36 @@ class GKSummary: GKSummary is a summary of a stream of numbers, which can be used to estimate quantiles. Examples: - >>> summary = GKSummary(0.01) + >>> summary = GKSummary(0.001) >>> summary += torch.tensor([1.0, 2.0, 3.0]) >>> summary += torch.tensor([4.0, 5.0, 6.0]) - >>> summary += torch.tensor([7.0, 8.0, 9.0]) - >>> summary.queries([0.1, 0.2]) - [2.0, 3.0] + >>> summary2 = GKSummary(0.001) + >>> summary2 += torch.tensor([7.0, 8.0, 9.0, 10.0]) + >>> summary = summary + summary2 + >>> summary.queries([0.1, 0.2, 0.7, 0.8]) + [1.0, 2.0, 7.0, 8.0] """ def __init__(self, epsilon: float) -> None: - self._summary = quantile.QuantileSummaryStream(epsilon) + self._epsilon = epsilon + self._summary = None + + def _get_summary(self): + if self._summary is None: + self._summary = quantile.QuantileSummaryStream(self._epsilon) + return self._summary def merge(self, other: "GKSummary"): """merge other summary into self.""" - self._summary.merge(other._summary) - return self + gk = GKSummary(self._epsilon) + gk._summary = self._get_summary().merge(other._get_summary()) + return gk def push(self, array: Union[torch.Tensor, numpy.ndarray]): """push elements in array into summary.""" if isinstance(array, torch.Tensor): array = array.numpy() - self._summary.insert_array(array) + self._get_summary().insert_array(array.astype(numpy.float64)) return self def __add__(self, other: "GKSummary"): @@ -52,6 +61,6 @@ def __iadd__(self, other: Union[torch.Tensor, numpy.ndarray]): return self.push(other) return NotImplemented - def queries(self, q: List[float]): + def queries(self, q: List[float]) -> List[float]: """return quantile values of q.""" - return torch.tensor(self._summary.queries(q)) + return self._get_summary().queries(q) From 55142bedc8fc68bf116cf96faf62ceeab557d649 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 12 Jul 2023 14:39:40 +0800 Subject: [PATCH 19/61] dataframe: fix setitem when rhs is str Signed-off-by: mgqa34 --- python/fate/arch/dataframe/ops/_set_item.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/fate/arch/dataframe/ops/_set_item.py b/python/fate/arch/dataframe/ops/_set_item.py index 7372e88b04..a8dd3a84da 100644 --- a/python/fate/arch/dataframe/ops/_set_item.py +++ b/python/fate/arch/dataframe/ops/_set_item.py @@ -96,7 +96,7 @@ def _append_tensor(l_blocks, r_tensor, bid_list=None, dm: DataManager=None): return 
ret_blocks data_manager = df.data_manager - if isinstance(items, (bool, int, float, np.int32, np.float32, np.int64, np.float64, np.bool)): + if isinstance(items, (bool, int, float, str, np.int32, np.float32, np.int64, np.float64, np.bool)): bids = data_manager.append_columns(keys, BlockType.get_block_type(items)) _append_func = functools.partial(_append_single, item=items, col_len=len(keys), bid=bids[0], dm=data_manager) block_table = df.block_table.mapValues(_append_func) @@ -193,7 +193,7 @@ def _replace_tensor(blocks, r_tensor, narrow_loc=None, dst_bids=None, dm: DataMa return ret_blocks data_manager = df.data_manager - if isinstance(items, (bool, int, float, np.int32, np.float32, np.int64, np.float64, np.bool)): + if isinstance(items, (bool, int, float, str, np.int32, np.float32, np.int64, np.float64, np.bool)): narrow_blocks, dst_blocks = data_manager.split_columns(keys, BlockType.get_block_type(items)) replace_func = functools.partial(_replace_single, item=items, narrow_loc=narrow_blocks, dst_bids=dst_blocks, dm=data_manager) From efdd0643d34723c6b016f4336fa3597653b03d26 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 12 Jul 2023 14:57:05 +0800 Subject: [PATCH 20/61] dataloader: add batch encoding Signed-off-by: mgqa34 --- python/fate/arch/dataframe/__init__.py | 3 +- python/fate/arch/dataframe/utils/__init__.py | 1 + .../fate/arch/dataframe/utils/_dataloader.py | 59 +++++++++++-------- 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/python/fate/arch/dataframe/__init__.py b/python/fate/arch/dataframe/__init__.py index 7e2a7d1dd2..f0a0b1ed66 100644 --- a/python/fate/arch/dataframe/__init__.py +++ b/python/fate/arch/dataframe/__init__.py @@ -21,7 +21,7 @@ TorchDataSetReader, ) from .io import build_schema, deserialize, parse_schema, serialize -from .utils import DataLoader +from .utils import DataLoader, BatchEncoding from .utils import KFold __all__ = [ @@ -37,4 +37,5 @@ "DataFrame", "KFold", "DataLoader", + "BatchEncoding" ] diff --git a/python/fate/arch/dataframe/utils/__init__.py b/python/fate/arch/dataframe/utils/__init__.py index 81e6eb9b51..9a0b6c5b07 100644 --- a/python/fate/arch/dataframe/utils/__init__.py +++ b/python/fate/arch/dataframe/utils/__init__.py @@ -13,4 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
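+# BatchEncoding standardizes what DataLoader yields per mini-batch; an
+# illustrative consumer, assuming a hetero guest loader over a fate DataFrame
+# `df` inside a context `ctx`:
+#   loader = DataLoader(df, ctx=ctx, batch_size=32, mode="hetero", role="guest")
+#   for batch in loader:
+#       X, y, w = batch.x, batch.label, batch.weight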
from ._dataloader import DataLoader +from ._dataloader import BatchEncoding from ._k_fold import KFold diff --git a/python/fate/arch/dataframe/utils/_dataloader.py b/python/fate/arch/dataframe/utils/_dataloader.py index 59c873094d..119e0b18bf 100644 --- a/python/fate/arch/dataframe/utils/_dataloader.py +++ b/python/fate/arch/dataframe/utils/_dataloader.py @@ -63,13 +63,6 @@ def _init_settings(self): else: raise ValueError(f"batch strategy {self._batch_strategy} is not support") - def next_batch(self, with_index=True): - batch = next(self._batch_generator) - if with_index: - return batch - else: - return batch[1:] - @staticmethod def batch_num(self): return self._batch_generator.batch_num @@ -154,31 +147,49 @@ def _prepare(self): def __next__(self): if self._role == "arbiter": for batch_id in range(self._batch_num): - yield batch_id, batch_id + yield BatchEncoding(batch_id=batch_id) return - for batch in self._batch_splits: + for bid, batch in enumerate(self._batch_splits): if batch.label and batch.weight: - yield batch.values.as_tensor(), batch.label.as_tensr(), batch.weight.as_tensor() + yield BatchEncoding(x=batch.values.as_tensor(), + label=batch.label.as_tensor(), + weight=batch.weight.as_tensor(), + batch_id=bid) elif batch.label: - yield batch.values.as_tensor(), batch.label.as_tensor() + yield BatchEncoding(x=batch.values.as_tensor(), + label=batch.label.as_tensor(), + batch_id=bid) else: - yield batch.values.as_tensor() + yield BatchEncoding(x=batch.values.as_tensor()) def __iter__(self): - if self._role == "arbiter": - for batch_id in range(self._batch_num): - yield batch_id, batch_id - return - - for batch in self._batch_splits: - if batch.label and batch.weight: - yield batch.values.as_tensor(), batch.label.as_tensor(), batch.weight.as_tensor() - elif batch.label: - yield batch.values.as_tensor(), batch.label.as_tensor() - else: - yield batch.values.as_tensor() + return self.__next__() @property def batch_num(self): return self._batch_num + + +class BatchEncoding(object): + def __init__(self, x=None, label=None, weight=None, batch_id=None): + self._x = x + self._label = label + self._weight = weight + self._batch_id = batch_id + + @property + def x(self): + return self._x + + @property + def label(self): + return self._label + + @property + def weight(self): + return self._weight + + @property + def batch_id(self): + return self._batch_id From d76a7116b95ec057380130833058d1d28b4bc6af Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Wed, 12 Jul 2023 15:13:48 +0800 Subject: [PATCH 21/61] group coordinated lr & linr into hetero glm (#4659) use new batch loader format(#4659) fix optimizer state dict serialization (#4659) Signed-off-by: Yu Wu --- .../components/components/coordinated_linr.py | 10 +++++----- python/fate/ml/glm/__init__.py | 4 ++-- python/fate/ml/glm/hetero/__init__.py | 0 .../{ => hetero}/coordinated_linr/__init__.py | 0 .../{ => hetero}/coordinated_linr/arbiter.py | 0 .../{ => hetero}/coordinated_linr/guest.py | 19 +++++++++--------- .../glm/{ => hetero}/coordinated_linr/host.py | 3 ++- .../{ => hetero}/coordinated_lr/__init__.py | 0 .../{ => hetero}/coordinated_lr/arbiter.py | 0 .../glm/{ => hetero}/coordinated_lr/guest.py | 20 +++++++++---------- .../glm/{ => hetero}/coordinated_lr/host.py | 3 ++- python/fate/ml/utils/_optimizer.py | 14 +++++++++++-- 12 files changed, 41 insertions(+), 32 deletions(-) create mode 100644 python/fate/ml/glm/hetero/__init__.py rename python/fate/ml/glm/{ => hetero}/coordinated_linr/__init__.py (100%) rename python/fate/ml/glm/{ => 
hetero}/coordinated_linr/arbiter.py (100%) rename python/fate/ml/glm/{ => hetero}/coordinated_linr/guest.py (92%) rename python/fate/ml/glm/{ => hetero}/coordinated_linr/host.py (98%) rename python/fate/ml/glm/{ => hetero}/coordinated_lr/__init__.py (100%) rename python/fate/ml/glm/{ => hetero}/coordinated_lr/arbiter.py (100%) rename python/fate/ml/glm/{ => hetero}/coordinated_lr/guest.py (96%) rename python/fate/ml/glm/{ => hetero}/coordinated_lr/host.py (98%) diff --git a/python/fate/components/components/coordinated_linr.py b/python/fate/components/components/coordinated_linr.py index ee9093b774..fbdc0ebcf0 100644 --- a/python/fate/components/components/coordinated_linr.py +++ b/python/fate/components/components/coordinated_linr.py @@ -95,7 +95,7 @@ def predict( def train_guest(ctx, train_data, validate_data, train_output_data, output_model, epochs, batch_size, optimizer_param, learning_rate_param, init_param): logger.info(f"coordinated linr guest start train") - from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleGuest + from fate.ml.glm import CoordinatedLinRModuleGuest # optimizer = optimizer_factory(optimizer_param) sub_ctx = ctx.sub_ctx("train") module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size, @@ -126,7 +126,7 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model, optimizer_param, learning_rate_param, init_param): logger.info(f"coordinated linr host start train") - from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleHost + from fate.ml.glm import CoordinatedLinRModuleHost # optimizer = optimizer_factory(optimizer_param) sub_ctx = ctx.sub_ctx("train") @@ -151,7 +151,7 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param, output_model): logger.info(f"coordinated linr arbiter start train") - from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleArbiter + from fate.ml.glm import CoordinatedLinRModuleArbiter sub_ctx = ctx.sub_ctx("train") module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size, @@ -165,7 +165,7 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, def predict_guest(ctx, input_model, test_data, test_output_data): - from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleGuest + from fate.ml.glm import CoordinatedLinRModuleGuest sub_ctx = ctx.sub_ctx("predict") model = input_model.read() @@ -178,7 +178,7 @@ def predict_guest(ctx, input_model, test_data, test_output_data): def predict_host(ctx, input_model, test_data, test_output_data): - from fate.ml.glm.coordinated_linr import CoordinatedLinRModuleHost + from fate.ml.glm import CoordinatedLinRModuleHost sub_ctx = ctx.sub_ctx("predict") model = input_model.read() diff --git a/python/fate/ml/glm/__init__.py b/python/fate/ml/glm/__init__.py index 299acc3dcb..f508993cb8 100644 --- a/python/fate/ml/glm/__init__.py +++ b/python/fate/ml/glm/__init__.py @@ -1,2 +1,2 @@ -from .coordinated_linr import CoordinatedLinRModuleHost, CoordinatedLinRModuleGuest, CoordinatedLinRModuleArbiter -from .coordinated_lr import CoordinatedLRModuleHost, CoordinatedLRModuleGuest, CoordinatedLRModuleArbiter +from .hetero.coordinated_linr import CoordinatedLinRModuleHost, CoordinatedLinRModuleGuest, CoordinatedLinRModuleArbiter +from .hetero.coordinated_lr import CoordinatedLRModuleHost, CoordinatedLRModuleGuest, CoordinatedLRModuleArbiter diff --git a/python/fate/ml/glm/hetero/__init__.py b/python/fate/ml/glm/hetero/__init__.py new file mode 
100644 index 0000000000..e69de29bb2 diff --git a/python/fate/ml/glm/coordinated_linr/__init__.py b/python/fate/ml/glm/hetero/coordinated_linr/__init__.py similarity index 100% rename from python/fate/ml/glm/coordinated_linr/__init__.py rename to python/fate/ml/glm/hetero/coordinated_linr/__init__.py diff --git a/python/fate/ml/glm/coordinated_linr/arbiter.py b/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py similarity index 100% rename from python/fate/ml/glm/coordinated_linr/arbiter.py rename to python/fate/ml/glm/hetero/coordinated_linr/arbiter.py diff --git a/python/fate/ml/glm/coordinated_linr/guest.py b/python/fate/ml/glm/hetero/coordinated_linr/guest.py similarity index 92% rename from python/fate/ml/glm/coordinated_linr/guest.py rename to python/fate/ml/glm/hetero/coordinated_linr/guest.py index d0a89c4b56..fed5c8c50b 100644 --- a/python/fate/ml/glm/coordinated_linr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/guest.py @@ -43,7 +43,6 @@ def __init__( self.estimator = None def fit(self, ctx: Context, train_data, validate_data=None) -> None: - with_weight = train_data.weight is not None optimizer = Optimizer( self.optimizer_param["method"], self.optimizer_param["penalty"], @@ -57,7 +56,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: optimizer=optimizer, learning_rate_scheduler=lr_scheduler, init_param=self.init_param) - estimator.fit_model(ctx, train_data, validate_data, with_weight=with_weight) + estimator.fit_model(ctx, train_data, validate_data) self.estimator = estimator def predict(self, ctx, test_data): @@ -107,7 +106,7 @@ def __init__( self.end_epoch = -1 self.is_converged = False - def fit_model(self, ctx, train_data, validate_data=None, with_weight=False): + def fit_model(self, ctx, train_data, validate_data=None): coef_count = train_data.shape[1] if self.init_param.get("fit_intercept"): train_data["intercept"] = 1 @@ -117,8 +116,7 @@ def fit_model(self, ctx, train_data, validate_data=None, with_weight=False): self.optimizer.init_optimizer(model_parameter_length=w.size()[0]) self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer) batch_loader = dataframe.DataLoader( - train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True, - # with_weight=True + train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True ) if self.end_epoch >= 0: self.start_epoch = self.end_epoch + 1 @@ -126,9 +124,10 @@ def fit_model(self, ctx, train_data, validate_data=None, with_weight=False): for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch {i}") - # todo: if self.with_weight: include weight in batch result - # for batch_ctx, (X, Y, weight) in iter_ctx.iter(batch_loader): - for batch_ctx, (X, Y) in iter_ctx.on_batches.ctxs_zip(batch_loader): + for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): + X = batch_data.x + Y = batch_data.label + weight = batch_data.weight h = X.shape[0] Xw = torch.matmul(X, w.detach()) d = Xw - Y @@ -141,8 +140,8 @@ def fit_model(self, ctx, train_data, validate_data=None, with_weight=False): d += Xw_h loss += 1 / h * torch.matmul(Xw.T, Xw_h) - # if with_weight: - # d = d * weight + if weight: + d = d * weight batch_ctx.hosts.put(d=d) for Xw2_h in batch_ctx.hosts.get("Xw2_h"): diff --git a/python/fate/ml/glm/coordinated_linr/host.py b/python/fate/ml/glm/hetero/coordinated_linr/host.py similarity index 98% rename from 
python/fate/ml/glm/coordinated_linr/host.py rename to python/fate/ml/glm/hetero/coordinated_linr/host.py index 8dd203a01c..5b3807b350 100644 --- a/python/fate/ml/glm/coordinated_linr/host.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/host.py @@ -121,7 +121,8 @@ def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) -> for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch {i}") - for batch_ctx, X in iter_ctx.on_batches.ctxs_zip(batch_loader): + for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): + X = batch_data.x h = X.shape[0] Xw_h = torch.matmul(X, w.detach()) batch_ctx.guest.put("Xw_h", encryptor.encrypt(Xw_h)) diff --git a/python/fate/ml/glm/coordinated_lr/__init__.py b/python/fate/ml/glm/hetero/coordinated_lr/__init__.py similarity index 100% rename from python/fate/ml/glm/coordinated_lr/__init__.py rename to python/fate/ml/glm/hetero/coordinated_lr/__init__.py diff --git a/python/fate/ml/glm/coordinated_lr/arbiter.py b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py similarity index 100% rename from python/fate/ml/glm/coordinated_lr/arbiter.py rename to python/fate/ml/glm/hetero/coordinated_lr/arbiter.py diff --git a/python/fate/ml/glm/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py similarity index 96% rename from python/fate/ml/glm/coordinated_lr/guest.py rename to python/fate/ml/glm/hetero/coordinated_lr/guest.py index 8ae83ea756..9de4817835 100644 --- a/python/fate/ml/glm/coordinated_lr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py @@ -53,7 +53,6 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: ctx.arbiter.put("label_count", label_count) ctx.hosts.put("label_count", label_count) self.labels = [label_name.split("_")[1] for label_name in train_data_binarized_label.columns] - with_weight = train_data.weight is not None if label_count > 2: logger.info(f"OVR data provided, will train OVR models.") self.ovr = True @@ -77,7 +76,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: init_param=self.init_param, ) train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]] - single_estimator.fit_single_model(class_ctx, train_data, validate_data, with_weight=with_weight) + single_estimator.fit_single_model(class_ctx, train_data, validate_data) self.estimator[i] = single_estimator else: optimizer = Optimizer( @@ -95,7 +94,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: learning_rate_scheduler=lr_scheduler, init_param=self.init_param, ) - single_estimator.fit_single_model(ctx, train_data, validate_data, with_weight=with_weight) + single_estimator.fit_single_model(ctx, train_data, validate_data) self.estimator = single_estimator train_data.label = original_label @@ -167,9 +166,8 @@ def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate_s self.start_epoch = 0 self.end_epoch = -1 self.is_converged = False - self.with_weight = False - def fit_single_model(self, ctx: Context, train_data, validate_data=None, with_weight=False): + def fit_single_model(self, ctx: Context, train_data, validate_data=None): """ l(w) = 1/h * Σ(log(2) - 0.5 * y * xw + 0.125 * (wx)^2) ∇l(w) = 1/h * Σ(0.25 * xw - 0.5 * y)x = 1/h * Σdx @@ -195,14 +193,14 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None, with_we ) if self.end_epoch >= 0: self.start_epoch = self.end_epoch + 1 - """if 
train_data.weight: - self.with_weight = True""" for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch {i}") - # todo: if self.with_weight: include weight in batch result - for batch_ctx, (X, Y) in iter_ctx.on_batches.ctxs_zip(batch_loader): + for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): + X = batch_data.x + Y = batch_data.label + weight = batch_data.weight h = X.shape[0] # logger.info(f"h: {h}") @@ -219,8 +217,8 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None, with_we d += Xw_h loss -= 0.5 / h * torch.matmul(Y.T, Xw_h) loss += 0.25 / h * torch.matmul(Xw.T, Xw_h) - # if with_weight: - # d = d * weight + if weight: + d = d * weight batch_ctx.hosts.put(d=d) for Xw2_h in batch_ctx.hosts.get("Xw2_h"): diff --git a/python/fate/ml/glm/coordinated_lr/host.py b/python/fate/ml/glm/hetero/coordinated_lr/host.py similarity index 98% rename from python/fate/ml/glm/coordinated_lr/host.py rename to python/fate/ml/glm/hetero/coordinated_lr/host.py index 09662e7b28..5ff4a8e024 100644 --- a/python/fate/ml/glm/coordinated_lr/host.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/host.py @@ -160,7 +160,8 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=No for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch{i}") - for batch_ctx, X in iter_ctx.on_batches.ctxs_zip(batch_loader): + for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): + X = batch_data.x h = X.shape[0] Xw_h = 0.25 * torch.matmul(X, w.detach()) batch_ctx.guest.put("Xw_h", encryptor.encrypt(Xw_h)) diff --git a/python/fate/ml/utils/_optimizer.py b/python/fate/ml/utils/_optimizer.py index 3cd578b168..c586443523 100644 --- a/python/fate/ml/utils/_optimizer.py +++ b/python/fate/ml/utils/_optimizer.py @@ -103,23 +103,33 @@ def shrinkage_val(self, lr): return self.alpha * this_step_size def state_dict(self): + optimizer_state_dict = self.optimizer.state_dict() + state_all = optimizer_state_dict['state'].get(0, {}) + for k, v in state_all.items(): + if isinstance(v, torch.Tensor): + state_all[k] = v.tolist() return { "l2_penalty": self.l2_penalty, "l1_penalty": self.l1_penalty, "alpha": self.alpha, - "optimizer": self.optimizer.state_dict(), + "optimizer": optimizer_state_dict, "method": self.method, "optim_param": self.optim_param, "model_parameter": self.model_parameter.tolist() } - def load_state_dict(self, dict, model_parameter=None): + def load_state_dict(self, dict): self.l2_penalty = dict["l2_penalty"] self.l1_penalty = dict["l1_penalty"] self.alpha = dict["alpha"] self.method = dict["method"] self.optim_param = dict["optim_param"] self.init_optimizer(model_parameter=torch.nn.parameter.Parameter(torch.tensor(dict["model_parameter"]))) + state_dict = dict["optimizer"] + state_all = state_dict['state'].get(0, {}) + for k, v in state_all.items(): + if isinstance(v, list): + state_all[k] = torch.tensor(v) self.optimizer.load_state_dict(dict["optimizer"]) def set_iters(self, new_iters): From a49706649d22942d70b1a6b80c912957afd425df Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 12 Jul 2023 15:29:38 +0800 Subject: [PATCH 22/61] Update Homo-LR: 1. None label predict 2. Training check 3. Predict support in component Update FATE-ML 1. 
Add tools for predict result formatting

Signed-off-by: cwj
---
 python/fate/components/components/homo_lr.py | 40 ++++++++++++++------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/python/fate/components/components/homo_lr.py b/python/fate/components/components/homo_lr.py
index 7dc567cc05..618b74c1d0 100644
--- a/python/fate/components/components/homo_lr.py
+++ b/python/fate/components/components/homo_lr.py
@@ -13,16 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import os
-
-import pandas as pd
 from fate.arch import Context
-from fate.arch.dataframe import PandasReader
 from fate.ml.glm.homo_lr.client import HomoLRClient
 from fate.ml.glm.homo_lr.server import HomoLRServer
-from fate.components.components.utils.predict_format import LABEL
 from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params
 from fate.components.components.utils import consts
+from fate.ml.utils.model_io import ModelIO
+

 logger = logging.getLogger(__name__)

@@ -59,25 +56,40 @@ def train(
                              desc="Model param init setting."),
     threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5,
                              desc="predict threshold for binary data"),
+    ovr: cpn.parameter(type=bool, default=False,
+                       desc="whether to run one-vs-rest (OVR) for multi-class tasks"),
+    label_num: cpn.parameter(type=params.conint(ge=2), default=None,
+                             desc="number of classes, required when ovr is True"),
     train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
     train_input_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True),
-    output_model: cpn.json_model_output(roles=[GUEST, HOST])
+    train_output_model: cpn.json_model_output(roles=[GUEST, HOST])
 ):

     sub_ctx = ctx.sub_ctx(consts.TRAIN)

     if role.is_guest or role.is_host:  # is client
+        logger.info('homo lr component: client start training')

         logger.info('optim param {} init param {}'.format(optimizer.dict(), init_param.dict()))
-        client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer, init_param=init_param,
-                              learning_rate_scheduler=0.01, threshold=threshold)
+        client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer.dict(), init_param=init_param.dict(),
+                              learning_rate_scheduler=0.01, threshold=threshold, ovr=ovr, label_num=label_num)
         train_df = train_data.read()
         validate_df = validate_data.read() if validate_data else None
         client.fit(sub_ctx, train_df, validate_df)
-        output_model.write({"aaa": 1}, metadata={"bbb": 2})
+        model_dict = client.get_model().dict()
+
+        train_rs = client.predict(sub_ctx, train_df)
+        if validate_df:
+            validate_rs = client.predict(sub_ctx, validate_df)
+            ret_df = train_rs.vstack(validate_rs)
+        else:
+            ret_df = train_rs
+
+        train_output_data.write(ret_df)
+        train_output_model.write(model_dict, metadata=model_dict['meta'])
+
     elif role.is_arbiter:  # is server
-        logger.info('hello')
+        logger.info('homo lr component: server start training')
         server = HomoLRServer()
         server.fit(sub_ctx)

@@ -97,7 +109,13 @@ def predict(
 ):

     if role.is_guest or role.is_host:  # is client
-        pass
+
+        client = HomoLRClient(batch_size=batch_size, threshold=threshold)
+        model_input = predict_input_model.read()
+        model_data = ModelIO.from_dict(model_input)
+        logger.info('model input is {}'.format(model_input))
+        pred_rs = client.predict(ctx, test_data.read())
+
     elif role.is_arbiter:  # is server
         logger.info("arbiter skip predict")

From 6639af63c05211c7e14de3ee8e2fdc004421e19e Mon Sep 17 00:00:00 2001
From: cwj
Date: Wed, 12 Jul 2023 15:30:08 +0800
Subject: [PATCH 23/61] Update

Signed-off-by: cwj
---
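The predict-result flow this patch standardizes, condensed into a minimal
sketch (`predict_rs`, `data` and `ctx` stand in for the objects used in
client.py below; the helpers are assumed to behave as their names suggest):

    from fate.ml.utils.predict_format import std_output_df, add_ids, to_fate_df, BINARY

    # build the standard predict dataframe from raw trainer output
    out_df = std_output_df(BINARY, predict_rs.predictions, predict_rs.label_ids,
                           threshold=0.5, classes=[0, 1])
    # re-attach the match/sample ids that were stripped off before training
    out_df = add_ids(out_df, data.match_ids, data.sample_ids)
    # convert back to a fate DataFrame for downstream components
    fate_df = to_fate_df(ctx, data.get_sample_id_name(), data.get_match_id_name(), out_df)
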
.../fate/components/components/evaluation.py | 7 +- .../fate/components/components/utils/tools.py | 16 +++ python/fate/ml/evaluation/classification.py | 2 +- python/fate/ml/glm/homo_lr/client.py | 107 +++++++++++++++--- python/fate/ml/glm/homo_lr/test/local_test.py | 12 +- python/fate/ml/utils/predict_format.py | 84 ++++++++++++++ 6 files changed, 198 insertions(+), 30 deletions(-) create mode 100644 python/fate/components/components/utils/tools.py create mode 100644 python/fate/ml/utils/predict_format.py diff --git a/python/fate/components/components/evaluation.py b/python/fate/components/components/evaluation.py index 6012b5af87..3c560276e2 100644 --- a/python/fate/components/components/evaluation.py +++ b/python/fate/components/components/evaluation.py @@ -79,12 +79,11 @@ def evaluate(input_data, metrics): data = input_data.as_pd_df() split_dict = split_dataframe_by_type(data) rs_dict = {} - logger.info('eval dataframe is {}'.format(data)) - + for name, df in split_dict.items(): - - logger.info('eval dataframe is {}'.format(df)) + + logger.info('eval dataframe is \n\n{}'.format(df)) y_true = df[LABEL] # in case is multi result, use tolist y_pred = df[PREDICT_SCORE] diff --git a/python/fate/components/components/utils/tools.py b/python/fate/components/components/utils/tools.py new file mode 100644 index 0000000000..650f330001 --- /dev/null +++ b/python/fate/components/components/utils/tools.py @@ -0,0 +1,16 @@ +from fate.arch.dataframe import DataFrame +from .consts import TRAIN_SET, VALIDATE_SET, TESET_SET + + +def cat_train_and_validate_df(train_df: DataFrame, val_df: DataFrame): + """ + Concatenate train and validate dataframe + """ + return train_df.vstack(val_df) + + +def add_dataset_type(df: DataFrame, dataset_type): + assert dataset_type in [TRAIN_SET, VALIDATE_SET, TESET_SET], f"dataset_type must be one of {TRAIN_SET}, {VALIDATE_SET}, {TESET_SET}" + return df + + diff --git a/python/fate/ml/evaluation/classification.py b/python/fate/ml/evaluation/classification.py index c7ac33372b..9031cbb9ac 100644 --- a/python/fate/ml/evaluation/classification.py +++ b/python/fate/ml/evaluation/classification.py @@ -45,7 +45,7 @@ def __call__(self, predict, label, **kwargs) -> Dict: predict = self.to_np_format(predict, flatten=False) label = self.to_np_format(label).astype(np.int32) if predict.shape != label.shape: - predict = predict.argmax(axis=-1) + predict = predict.argmax(axis=-1).astype(np.int32) acc = accuracy_score(label, predict) return EvalResult(self.metric_name, acc) diff --git a/python/fate/ml/glm/homo_lr/client.py b/python/fate/ml/glm/homo_lr/client.py index 81632087bd..5920af2df6 100644 --- a/python/fate/ml/glm/homo_lr/client.py +++ b/python/fate/ml/glm/homo_lr/client.py @@ -14,6 +14,8 @@ import functools import tempfile from torch.utils.data import Dataset +from fate.ml.utils.predict_format import std_output_df, add_ids, to_fate_df +from fate.ml.utils.predict_format import MULTI, BINARY logger = logging.getLogger(__name__) @@ -28,18 +30,31 @@ def __init__(self, features: pd.DataFrame, sample_ids: pd.DataFrame, match_ids: self.match_ids = match_ids self.labels = labels + def get_match_id_name(self): + return self.match_ids.columns[0] + + def get_sample_id_name(self): + return self.sample_ids.columns[0] + + def has_label(self): + return self.labels is not None + @staticmethod def from_fate_dataframe(df: DataFrame): schema = df.schema sample_id = schema.sample_id_name match_id = schema.match_id_name label = schema.label_name - logger.info('columns are {} {} {}'.format(sample_id, 
         pd_df = df.as_pd_df()
-        features = pd_df.drop([sample_id, match_id, label], axis=1)
+        if label is None:
+            labels = None
+            features = pd_df.drop([sample_id, match_id], axis=1)
+        else:
+            labels = pd_df[[label]]
+            features = pd_df.drop([sample_id, match_id, label], axis=1)
         sample_ids = pd_df[[sample_id]]
         match_ids = pd_df[[match_id]]
-        labels = pd_df[[label]]
+
         return Data(features, sample_ids, match_ids, labels)

@@ -178,23 +193,31 @@ class DictDataset(Dataset):
     """
     def __init__(self, data):
         self.X = np.array(data.features.values).astype(np.float32)
-        self.y = np.array(data.labels.values).astype(np.float32)
         self.X_tensor = t.tensor(self.X, dtype=t.float32)
-        self.y_tensor = t.tensor(self.y.reshape((-1, 1)), dtype=t.float32)
-
+        if data.labels is None:
+            self.y = None
+        else:
+            self.y = np.array(data.labels.values).astype(np.float32)
+            self.y_tensor = t.tensor(self.y.reshape((-1, 1)), dtype=t.float32)
+
     def __getitem__(self, index):
-        return {'x': self.X_tensor[index], 'label': self.y_tensor[index]}
+        if self.y is not None:
+            return {'x': self.X_tensor[index], 'label': self.y_tensor[index]}
+        else:
+            return {'x': self.X_tensor[index]}

     def __len__(self):
         return self.X_tensor.shape[0]
-
+

 class HomoLRClient(HomoModule):

     def __init__(self, epochs: int=5, batch_size: int=32, optimizer_param=None,
                  learning_rate_scheduler=None, init_param=None,
-                 threshold: float=0.5
+                 threshold: float=0.5,
+                 ovr=False,
+                 label_num=None,
                  ) -> None:

         super().__init__()
@@ -213,6 +236,12 @@ def __init__(self, epochs: int=5, batch_size: int=32, optimizer_param=None,
         self.run_ovr = False
         self.train_feature_num = None
         self.validate_feature_num = None
+        self.ovr = ovr
+        self.label_num = label_num
+
+        if self.ovr:
+            if self.label_num is None or self.label_num < 2:
+                raise ValueError("label_num must be set and no less than 2 when ovr is True, but got {}".format(self.label_num))

         # models & optimizer & scheduler
         self.model = None
@@ -235,17 +264,55 @@ def __init__(self, epochs: int=5, batch_size: int=32, optimizer_param=None,
     def _make_dataset(self, data: Data):
         return DictDataset(data)
+
+    def _make_output_df(self, predict_rs, data: Data, threshold: float):
+        classes = [i for i in range(len(self.model.models))]
+        if len(classes) == 1:  # binary
+            classes = [0, 1]
+        task_type = BINARY if len(classes) == 2 else MULTI
+        out_df = std_output_df(task_type, predict_rs.predictions, predict_rs.label_ids, threshold=threshold, classes=classes)
+        out_df = add_ids(out_df, data.match_ids, data.sample_ids)
+        return out_df
+
+    def _check_labels(self, label_set, has_validate=False):
+
+        dataset_desc = 'train dataset' if not has_validate else 'train and validate dataset'
+        if not self.ovr and len(label_set) > 2:
+            raise ValueError("please set ovr=True to enable multi-class classification, multiple labels found in {}: {}".format(dataset_desc, label_set))
+        if not self.ovr and len(label_set) == 2:
+            # 0, 1 is required
+            if 0 not in label_set or 1 not in label_set:
+                # ask for label 0, 1 when running binary classification
+                raise ValueError("when doing binary classification, labels must be 0, 1, but {}'s label set is {}".format(dataset_desc, label_set))
+        if self.ovr:
+            if max(label_set) > self.label_num - 1:
+                # make sure labels start from 0 and label indices do not exceed the label_num parameter
+                raise ValueError("when doing multi-class classification, labels must start from 0 and not exceed the label num parameter, \
+                    but {}'s label set is {}, while label num is {}".format(dataset_desc, label_set,
self.label_num)) def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = None) -> None: - self.train_data = Data.from_fate_dataframe(train_data) + # check data, must be fate Dataframe + assert isinstance(train_data, DataFrame), "train_data must be a fate DataFrame" + if validate_data is not None: + assert isinstance(validate_data, DataFrame), "validate_data must be a fate DataFrame" + + self.train_data: Data = Data.from_fate_dataframe(train_data) + if not self.train_data.has_label(): + raise RuntimeError("train data must have label column") self.train_feature_num = self.train_data.features.values.shape[1] + unique_label_set = set(self.train_data.labels.values.reshape(-1)) + if validate_data is not None: self.validate_data = Data.from_fate_dataframe(validate_data) + if not self.validate_data.has_label(): + raise RuntimeError("validate data must have label column") self.validate_feature_num = self.validate_data.features.values.shape[1] assert self.train_feature_num == self.validate_feature_num, "train and validate feature num not match: {} vs {}".format(self.train_feature_num, self.validate_feature_num) + unique_label_set = unique_label_set.union(set(self.validate_data.labels.values.reshape(-1))) + + self._check_labels(unique_label_set, validate_data is not None) - unique_label_set = set(self.train_data.labels.values.reshape(-1)) if validate_data is not None: unique_label_set = unique_label_set.union(set(self.validate_data.labels.values.reshape(-1))) logger.info("unique label set updated to: {}".format(unique_label_set)) @@ -271,7 +338,7 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No logger.info('model initialized') logger.info('model parameters are {}'.format(list(self.model.parameters()))) else: - logger.info('model is loaded') + logger.info('model is loaded, warm start training') logger.info('model structure is {}'.format(self.model)) # initialize optimizer @@ -283,6 +350,7 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No } self.optimizer.load_state_dict(optimizer_state_dict) logger.info('load warmstart optimizer state dict') + # training fed_arg = FedAVGArguments() train_arg = TrainingArguments(num_train_epochs=self.max_iter, @@ -290,6 +358,8 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No self.trainer = FedAVGCLient(ctx, model=self.model, loss_fn=loss_fn, optimizer=self.optimizer, train_set=train_set, val_set=validate_set, training_args=train_arg, fed_args=fed_arg, data_collator=default_data_collator) self.trainer.train() + + logger.info('training finished') def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: @@ -305,18 +375,19 @@ def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: else: trainer = self.trainer predict_rs = trainer.predict(predict_set) - rs = {"predict_score": predict_rs.predictions, 'label': predict_rs.label_ids} - return rs + predict_out_df = self._make_output_df(predict_rs, self.predict_data, self.threshold) + return to_fate_df(ctx, self.predict_data.get_sample_id_name(), self.predict_data.get_match_id_name(), predict_out_df) def get_model(self) -> ModelIO: param = {} if self.model is not None: param['model'] = self.model.to_dict() if self.optimizer is not None: - param['optimizer'] = get_torch_bytes(self.optimizer.state_dict()) + param['optimizer'] = str(get_torch_bytes(self.optimizer.state_dict())) meta = {'batch_size': self.batch_size, 'max_iter': self.max_iter, 'threshold': self.threshold, - 
'optimizer_param': self.optimizer_param, 'learning_rate_param': self.learning_rate_param, 'init_param': self.init_param} + 'optimizer_param': self.optimizer_param, 'learning_rate_param': self.learning_rate_param, 'init_param': self.init_param, 'ovr': self.ovr, + 'label_num': self.label_num} export_ = ModelIO(data=param, meta=meta) return export_ @@ -331,9 +402,11 @@ def from_model(self, model: ModelIO): if not 'model' in model_param: raise ValueError("param dict must have key 'model' that contains the model parameter and structure info") self.model = HomoLRModel.from_dict(model_param['model']) + if self.ovr: + assert len(self.model.models) == self.label_num, '' self.model.l1 = self.l1 if hasattr(model_param, 'optimizer'): - self.optimizer_state_dict = recover_torch_bytes(model_param['optimizer']) + self.optimizer_state_dict = recover_torch_bytes(bytes(model_param['optimizer'], 'utf-8')) self.loaded_meta = model['meta'] diff --git a/python/fate/ml/glm/homo_lr/test/local_test.py b/python/fate/ml/glm/homo_lr/test/local_test.py index 89da930ff1..d1661f7d22 100644 --- a/python/fate/ml/glm/homo_lr/test/local_test.py +++ b/python/fate/ml/glm/homo_lr/test/local_test.py @@ -30,20 +30,16 @@ df['sample_id'] = [i for i in range(len(df))] reader = PandasReader(sample_id_name='sample_id', match_id_name="id", label_name="y", dtype="object") +reader_2 = PandasReader(sample_id_name='sample_id', match_id_name="id", dtype="object") data = reader.to_frame(ctx, df) df = data.as_pd_df() +data_2 = reader_2.to_frame(ctx, df.drop(columns=['y'])) client = HomoLRClient(50, 800, learning_rate_scheduler=0.01) client.l2 = 0.01 client.l1 = 0.01 -client.fit(ctx, data) +client.fit(ctx, data, validate_data=data) export_model = client.get_model() pred = client.predict(ctx, data) +pred_2 = client.predict(ctx, data_2) -# print('load model and warm-starting') -# client_2 = HomoLRClient(1, batch_size=800, learning_rate_param=0.001) -# client_2.from_model(export_model) -# client_2.fit(ctx, data) - -# from fate.components.core.params._learning_rate import LRSchedulerParam -# from fate.components.core.params._optimizer import OptimizerParam diff --git a/python/fate/ml/utils/predict_format.py b/python/fate/ml/utils/predict_format.py new file mode 100644 index 0000000000..994c622078 --- /dev/null +++ b/python/fate/ml/utils/predict_format.py @@ -0,0 +1,84 @@ +import pandas as pd +from fate.arch.dataframe import PandasReader +import numpy as np + + +TRAIN_SET = 'train_set' +VALIDATE_SET = 'validate_set' +TEST_SET = 'test_set' +LABEL = "label" +PREDICT_LABEL = "predict_result" +PREDICT_SCORE = "predict_score" +PREDICT_DETAIL = "predict_detail" +TYPE = "type" + +# TASK TYPE +BINARY = 'binary' +MULTI = 'multi' +REGRESSION = 'regression' +OTHER = 'other' + + +def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id:pd.DataFrame): + df = pd.concat([df, match_id, sample_id], axis=1) + return df + + +def add_dataset_type(df: pd.DataFrame, ds_type): + + assert ds_type in [TRAIN_SET, VALIDATE_SET, TEST_SET], 'ds_type must be one of {}, but got {}'.format([TRAIN_SET, VALIDATE_SET, TEST_SET], ds_type) + df[TYPE] = ds_type + return df + + +def to_fate_df(ctx, sample_id_name, match_id_name, result_df): + + if LABEL in result_df: + reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, label_name=LABEL, dtype="object") + else: + reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, dtype="object") + data = reader.to_frame(ctx, result_df) + return data + + +def std_output_df(task_type, 
pred: np.ndarray, label: np.ndarray = None, threshold=0.5, classes: list = None):
+
+    assert task_type in [BINARY, MULTI, REGRESSION], 'task_type must be one of {} as a std task, but got {}'.format([BINARY, MULTI, REGRESSION], task_type)
+
+    if task_type == BINARY:
+        if len(classes) == 2:
+            predict_score = pred
+            predict_result = (predict_score > threshold).astype(int)
+            predict_details = [{classes[0]: 1 - float(predict_score[i]), classes[1]: float(predict_score[i])} for i in range(len(predict_score))]
+        else:
+            raise ValueError('task_type is binary, but classes length is not 2: {}'.format(classes))
+
+    elif task_type == MULTI:
+        if len(classes) > 2:
+            predict_score = pred.max(axis=1)
+            predict_result = np.argmax(pred, axis=1)
+            predict_details = [{classes[j]: float(pred[i][j]) for j in range(len(classes))} for i in range(len(pred))]
+        else:
+            raise ValueError('task_type is multi, but classes length is not greater than 2: {}'.format(classes))
+
+    elif task_type == REGRESSION:
+        # regression task
+        predict_score = pred
+        predict_result = pred
+        predict_details = [{LABEL: float(pred[i])} for i in range(len(pred))]
+
+    if label is None:
+        df = pd.DataFrame({
+            PREDICT_SCORE: predict_score.flatten(),
+            PREDICT_LABEL: predict_result.flatten(),
+            PREDICT_DETAIL: predict_details
+        })
+    else:
+        df = pd.DataFrame({
+            PREDICT_SCORE: predict_score.flatten(),
+            PREDICT_LABEL: predict_result.flatten(),
+            LABEL: label.flatten(),
+            PREDICT_DETAIL: predict_details
+        })
+
+    return df
\ No newline at end of file

From 37a08099f334e66aa51338f1fd4cea7ca59de392 Mon Sep 17 00:00:00 2001
From: Yu Wu
Date: Wed, 12 Jul 2023 15:41:11 +0800
Subject: [PATCH 24/61] fix guest gradient computation(#4659)

Signed-off-by: Yu Wu
---
 python/fate/ml/glm/hetero/coordinated_linr/guest.py | 2 +-
 python/fate/ml/glm/hetero/coordinated_lr/guest.py   | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/fate/ml/glm/hetero/coordinated_linr/guest.py b/python/fate/ml/glm/hetero/coordinated_linr/guest.py
index fed5c8c50b..c875e5938d 100644
--- a/python/fate/ml/glm/hetero/coordinated_linr/guest.py
+++ b/python/fate/ml/glm/hetero/coordinated_linr/guest.py
@@ -155,7 +155,7 @@ def fit_model(self, ctx, train_data, validate_data=None):
                 batch_ctx.arbiter.put(loss=loss)

             # gradient
-            g = 1 / h * X.T @ d
+            g = 1 / h * torch.matmul(X.T, d)
             g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept"))
             batch_ctx.arbiter.put("g_enc", g)
             g = batch_ctx.arbiter.get("g")

diff --git a/python/fate/ml/glm/hetero/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py
index 9de4817835..b141360f70 100644
--- a/python/fate/ml/glm/hetero/coordinated_lr/guest.py
+++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py
@@ -218,6 +218,7 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None):
                     loss -= 0.5 / h * torch.matmul(Y.T, Xw_h)
                     loss += 0.25 / h * torch.matmul(Xw.T, Xw_h)
                 if weight:
+                    logger.info(f"weight: {weight.tolist()}")
                     d = d * weight
                 batch_ctx.hosts.put(d=d)

@@ -232,7 +233,7 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None):
                 batch_ctx.arbiter.put(loss=loss)

             # gradient
-            g = 1 / h * X.T @ d
+            g = 1 / h * torch.matmul(X.T, d)
             g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept"))
             batch_ctx.arbiter.put("g_enc", g)
             g = batch_ctx.arbiter.get("g")

From e39746cea4f7392dab95f49a3e02a061249eb467 Mon Sep 17 00:00:00 2001
From: Yu Wu
Date: Wed, 12 Jul 2023 15:42:00 +0800
Subject: [PATCH 25/61] edit examples(#4659)
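Usage note for the predict-format helpers above (a sketch, not from the patch;
assumes a ready `ctx` and pandas frames `sample_ids`/`match_ids` holding the id
columns, as produced by Data.from_fate_dataframe):

    import numpy as np
    from fate.ml.utils.predict_format import BINARY, add_ids, std_output_df, to_fate_df

    scores = np.array([[0.2], [0.7], [0.9]])  # raw sigmoid outputs, shape (n, 1)
    labels = np.array([[0], [1], [1]])
    out_df = std_output_df(BINARY, scores, labels, threshold=0.5, classes=[0, 1])
    # out_df columns: predict_score, predict_result, label, predict_detail
    out_df = add_ids(out_df, match_ids, sample_ids)
    fate_df = to_fate_df(ctx, "sample_id", "match_id", out_df)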
Signed-off-by: Yu Wu --- examples/pipeline/test_lr_sid.py | 24 +++++++++++++----------- examples/pipeline/test_single_linr.py | 4 ++-- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/examples/pipeline/test_lr_sid.py b/examples/pipeline/test_lr_sid.py index 46d00e7f5f..3a61bd48ca 100644 --- a/examples/pipeline/test_lr_sid.py +++ b/examples/pipeline/test_lr_sid.py @@ -20,16 +20,18 @@ pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") intersect_0 = Intersection("intersect_0", method="raw") -intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_sid", - namespace="experiment")) -intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host_sid", - namespace="experiment")) +intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid")) +intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace="experiment_sid")) lr_0 = CoordinatedLR("lr_0", - epochs=10, + epochs=2, batch_size=100, - optimizer={"method": "sgd", "optimizer_params": {"lr": 0.01}, "alpha": 0.5}, + optimizer={"method": "sgd", "optimizer_params": {"lr": 0.01}}, init_param={"fit_intercept": True}, train_data=intersect_0.outputs["output_data"]) +lr_1 = CoordinatedLR("lr_1", test_data=intersect_0.outputs["output_data"], + input_model=lr_0.outputs["output_model"]) """lr_0.guest.component_setting(train_data=DataWarehouseChannel(name="breast_hetero_guest_sid", namespace="experiment")) @@ -56,15 +58,15 @@ print(f"evaluation metrics: ") print(pipeline.get_task_info("evaluation_0").get_output_metrics()) -pipeline.deploy([lr_0]) +pipeline.deploy([intersect_0, lr_0]) predict_pipeline = FateFlowPipeline() deployed_pipeline = pipeline.get_deployed_pipeline() -lr_0.guest.component_setting(test_data=DataWarehouseChannel(name="breast_hetero_guest_sid", - namespace="experiment")) -lr_0.hosts[0].component_setting(test_data=DataWarehouseChannel(name="breast_hetero_host_sid", - namespace="experiment")) +deployed_pipeline.intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid")) +deployed_pipeline.intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace="experiment_sid")) predict_pipeline.add_task(deployed_pipeline) predict_pipeline.compile() diff --git a/examples/pipeline/test_single_linr.py b/examples/pipeline/test_single_linr.py index bf04453eb8..e42bedebb7 100644 --- a/examples/pipeline/test_single_linr.py +++ b/examples/pipeline/test_single_linr.py @@ -28,7 +28,7 @@ input_model=feature_scale_0.outputs["output_model"])""" linr_0 = CoordinatedLinR("linr_0", - max_iter=10, + epochs=10, batch_size=-1, init_param={"fit_intercept": False}) @@ -38,7 +38,7 @@ namespace="experiment")) evaluation_0 = Evaluation("evaluation_0", - runtime_roles="guest", + runtime_roles=["guest"], input_data=linr_0.outputs["train_output_data"]) # pipeline.add_task(feature_scale_0) From 12e1afbb89a1cb9ffb6df2bc1250f0b1e066e60a Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Wed, 12 Jul 2023 17:51:02 +0800 Subject: [PATCH 26/61] fix selection param(#4661) Signed-off-by: Yu Wu --- .../components/hetero_feature_selection.py | 36 ++++++++++------ .../components/core/params/_filter_param.py | 8 ++-- .../hetero_feature_selection.py | 41 +++++++++++-------- 3 files changed, 50 insertions(+), 35 deletions(-) diff --git 
a/python/fate/components/components/hetero_feature_selection.py b/python/fate/components/components/hetero_feature_selection.py index ef6a71ccd7..543b812801 100644 --- a/python/fate/components/components/hetero_feature_selection.py +++ b/python/fate/components/components/hetero_feature_selection.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import List from fate.arch import Context from fate.components.core import GUEST, HOST, Role, cpn, params +logger = logging.getLogger(__name__) + @cpn.component(roles=[GUEST, HOST]) def hetero_feature_selection(ctx, role): @@ -60,12 +63,14 @@ def train( train_output_model: cpn.json_model_output(roles=[GUEST, HOST]) ): from fate.ml.feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest + logger.info(f"start selection train") sub_ctx = ctx.sub_ctx("train") train_data = train_data.read() columns = train_data.schema.columns.to_list() if use_anonymous: + logger.debug(f"use anonymous columns") anonymous_columns = train_data.schema.anonymous_columns.to_list() if select_col is not None: select_col = [columns[anonymous_columns.index(col)] for col in select_col] @@ -82,13 +87,22 @@ def train( # temp code end input_models = [model.read() for model in input_models] if role.is_guest: - selection = HeteroSelectionModuleGuest(method, select_col, input_models, - iv_param, statistic_param, manual_param, - keep_one) + selection = HeteroSelectionModuleGuest(method=method, + select_col=select_col, + input_models=input_models, + iv_param=iv_param, + statistic_param=statistic_param, + manual_param=manual_param, + keep_one=keep_one) + elif role.is_host: - selection = HeteroSelectionModuleHost(method, select_col, input_models, - iv_param, statistic_param, manual_param, - keep_one) + selection = HeteroSelectionModuleHost(method=method, + select_col=select_col, + input_models=input_models, + iv_param=iv_param, + statistic_param=statistic_param, + manual_param=manual_param, + keep_one=keep_one) else: raise ValueError(f"role: {role} is not valid") selection.fit(sub_ctx, train_data) @@ -111,10 +125,9 @@ def predict( test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]) ): from fate.ml.feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest - + logger.info(f"start selection predict") sub_ctx = ctx.sub_ctx("predict") - with input_model as model_reader: - model = model_reader.read_model() + model = input_model.read() if role.is_guest: selection = HeteroSelectionModuleGuest.from_model(model) elif role.is_host: @@ -122,12 +135,9 @@ def predict( else: raise ValueError(f"role: {role} is not valid") - model_meta = model["meta_data"] - method = model_meta["method"] - selection.method = method test_data = test_data.read() output_data = test_data - if method is not None: + if selection.method is not None: output_data = selection.transform(sub_ctx, test_data) test_output_data.write(output_data) diff --git a/python/fate/components/core/params/_filter_param.py b/python/fate/components/core/params/_filter_param.py index 4b6435dd4d..4983a4b9c9 100644 --- a/python/fate/components/core/params/_filter_param.py +++ b/python/fate/components/core/params/_filter_param.py @@ -76,15 +76,15 @@ class StatisticFilterParam(StandardFilterParam, Parameter): class ManualFilterParam(pydantic.BaseModel, Parameter): keep_col: List[str] = [] - left_out_col: List[str] = [] + filter_out_col: List[str] = [] @pydantic.root_validator(pre=False) def 
no_intersection(cls, values): - left_out_col = values.get('left_out_col', []) + filter_out_col = values.get('filter_out_col', []) keep_col = values.get('keep_col', []) - intersection = set(left_out_col).intersection(set(keep_col)) + intersection = set(filter_out_col).intersection(set(keep_col)) if intersection: - raise ValueError(f"`keep_col` and `left_out_col` share common elements: {intersection}") + raise ValueError(f"`keep_col` and `filter_out_col` share common elements: {intersection}") return values diff --git a/python/fate/ml/feature_selection/hetero_feature_selection.py b/python/fate/ml/feature_selection/hetero_feature_selection.py index 2a8558c89a..43fd12d963 100644 --- a/python/fate/ml/feature_selection/hetero_feature_selection.py +++ b/python/fate/ml/feature_selection/hetero_feature_selection.py @@ -44,13 +44,15 @@ def __init__(self, method=None, select_col=None, input_models=None, self._inner_method = [] self._selection_obj = [] - isometric_model_dict = {} - for model in input_models: - model_type = model["meta"].get("model_type") - if model_type is None: - raise ValueError(f"Missing 'model_type' in input model") - isometric_model_dict[model_type] = model - self.isometric_model_dict = isometric_model_dict + self.isometric_model_dict = None + if input_models: + isometric_model_dict = {} + for model in input_models: + model_type = model["meta"].get("model_type") + if model_type is None: + raise ValueError(f"Missing 'model_type' in input model") + isometric_model_dict[model_type] = model + self.isometric_model_dict = isometric_model_dict def fit(self, ctx: Context, train_data, validate_data=None) -> None: logger.info(f"isometric_model_dict: {self.isometric_model_dict}") @@ -158,13 +160,15 @@ def __init__(self, method=None, select_col=None, input_models=None, self._inner_method = [] self._selection_obj = [] - isometric_model_dict = {} - for model in input_models: - model_type = model["meta"].get("model_type") - if model_type is None: - raise ValueError(f"Missing 'model_type' in input model") - isometric_model_dict[model_type] = model - self.isometric_model_dict = isometric_model_dict + self.isometric_model_dict = None + if input_models: + isometric_model_dict = {} + for model in input_models: + model_type = model["meta"].get("model_type") + if model_type is None: + raise ValueError(f"Missing 'model_type' in input model") + isometric_model_dict[model_type] = model + self.isometric_model_dict = isometric_model_dict def fit(self, ctx: Context, train_data, validate_data=None) -> None: if self.select_col is None: @@ -273,10 +277,11 @@ def __init__(self, method, param=None, header=None, model=None, keep_one=True): self.keep_one = keep_one self._header = header self._prev_selected_mask = None + self._selected_mask = None if header is None: - self._selected_mask = None + self._prev_selected_mask = None else: - self._selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header) + self._prev_selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header) def set_selected_mask(self, mask): self._selected_mask = mask @@ -330,7 +335,7 @@ def to_model(self): def restore(self, model): self.method = model["method"] self.keep_one = model["keep_one"] - self._selected_mask = pd.Series(["selected_mask"], dtype=bool) + self._selected_mask = pd.Series(model["selected_mask"], dtype=bool) class StandardSelection(Module): @@ -570,7 +575,7 @@ def to_model(self): def restore(self, model): self.method = model["method"] self.keep_one = model["keep_one"] - self._selected_mask = 
pd.Series(["selected_mask"], dtype=bool) + self._selected_mask = pd.Series(model["selected_mask"], dtype=bool) self._all_selected_mask = pd.DataFrame(model["all_selected_mask"], dtype=bool) self._all_metrics = pd.DataFrame(model["all_metrics"]) self._host_selected_mask = {k: pd.Series(v, dtype=bool) for k, v in model["host_selected_mask"].items()} From 4c8bc248ee31a6a399f74b7bdee3d4f9a7a54d5d Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Wed, 12 Jul 2023 17:53:06 +0800 Subject: [PATCH 27/61] edit selection example(#4661) Signed-off-by: Yu Wu --- examples/pipeline/test_selection.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/pipeline/test_selection.py b/examples/pipeline/test_selection.py index 8bc347d296..b0a488961b 100644 --- a/examples/pipeline/test_selection.py +++ b/examples/pipeline/test_selection.py @@ -47,9 +47,10 @@ selection_0 = HeteroFeatureSelection("selection_0", train_data=feature_scale_0.outputs["train_output_data"], - method=["statistics"], + method=["manual", "statistics"], input_models=[statistics_0.outputs["output_model"]], - statistic_param={"metrics": ["mean", "max", "kurtosis", "skewness"]}) + statistic_param={"metrics": ["mean", "max", "kurtosis", "skewness"]}, + manual_param={"filter_out_col": ["x0", "x3"]}) pipeline.add_task(feature_scale_0) pipeline.add_task(statistics_0) From 4eba77a986b8ccdafa2e04acc5a7e833387093e8 Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 12 Jul 2023 19:50:50 +0800 Subject: [PATCH 28/61] fix parameter type check Signed-off-by: weiwee --- python/fate/components/core/component_desc/_parameter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/fate/components/core/component_desc/_parameter.py b/python/fate/components/core/component_desc/_parameter.py index 12803c0ab4..7300292db9 100644 --- a/python/fate/components/core/component_desc/_parameter.py +++ b/python/fate/components/core/component_desc/_parameter.py @@ -24,7 +24,8 @@ def merge(self, p: "ParameterDescribe"): raise ComponentParameterDuplicateError( f"parameter {p.name} declare multiple times with different optional: `{self.optional}` vs `{p.optional}`" ) - if str(self.type) != str(p.type) or self.type.__dict__ != p.type.__dict__: + # if str(self.type) != str(p.type) or self.type.__dict__ != p.type.__dict__: + if str(self.type) != str(p.type): raise ComponentParameterDuplicateError( f"parameter {p.name} declare multiple times with different type: `{self.type}({self.type.__dict__})` vs `{self.type}({self.type.__dict__})`" ) From 709166af6e3434c303e521ca8122aee96201fee7 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 12 Jul 2023 21:02:53 +0800 Subject: [PATCH 29/61] update toy example Signed-off-by: mgqa34 --- python/fate/components/components/__init__.py | 5 ++ .../fate/components/components/toy_example.py | 49 +++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 python/fate/components/components/toy_example.py diff --git a/python/fate/components/components/__init__.py b/python/fate/components/components/__init__.py index 54c6deaf73..216577bda1 100644 --- a/python/fate/components/components/__init__.py +++ b/python/fate/components/components/__init__.py @@ -102,6 +102,11 @@ def statistics(self): return statistics + @_lazy_cpn + def toy_example(self): + from .toy_example import toy_example + return toy_example + @_lazy_cpn def dataframe_io_test(self): from .dataframe_io_test import dataframe_io_test diff --git a/python/fate/components/components/toy_example.py 
b/python/fate/components/components/toy_example.py
new file mode 100644
index 0000000000..cdb4faf5de
--- /dev/null
+++ b/python/fate/components/components/toy_example.py
@@ -0,0 +1,49 @@
+#
+# Copyright 2023 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Union
+
+import pandas as pd
+from fate.arch import Context
+from fate.arch.dataframe import PandasReader
+from fate.components.core import GUEST, HOST, Role, cpn, params
+
+
+@cpn.component(roles=[GUEST, HOST])
+def toy_example(
+    ctx: Context,
+    role: Role,
+    output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
+    json_model_output: cpn.json_model_output(roles=[GUEST, HOST]),
+    data_num: cpn.parameter(type=params.conint(gt=1), desc="number of rows to generate", optional=False),
+    partition: cpn.parameter(type=params.conint(gt=1), desc="data_partition", optional=False),
+):
+    pd_df = pd.DataFrame([[str(i), str(i), i] for i in range(data_num)], columns=["sample_id", "match_id", "x0"])
+    reader = PandasReader(sample_id_name="sample_id", match_id_name="match_id", dtype="float64", partition=partition)
+    df = reader.to_frame(ctx, pd_df)
+
+    if role.is_guest:
+        ctx.hosts.put("guest_index", df.get_indexer(target="sample_id"))
+        host_indexes = ctx.hosts[0].get("host_index")
+        final_df = df.loc(host_indexes, preserve_order=True)
+    else:
+        guest_indexes = ctx.guest.get("guest_index")
+        final_df = df.loc(guest_indexes)
+        ctx.guest.put("host_index", final_df.get_indexer(target="sample_id"))
+
+    assert final_df.shape[0] == data_num, f"data num should be {data_num} instead of {final_df.shape[0]}"
+
+    output_data.write(final_df)
+
+    json_model_output.write({"test_role": role})

From a8327f5f47ec7b8150f6ef976634f191d8db396f Mon Sep 17 00:00:00 2001
From: Yu Wu
Date: Thu, 13 Jul 2023 10:47:17 +0800
Subject: [PATCH 30/61] coordinated lr add cv(#4659)

Signed-off-by: Yu Wu
---
 examples/pipeline/test_lr_sid_cv.py           |  52 +++++++
 .../components/components/coordinated_lr.py   | 141 +++++++++++++-----
 .../fate/components/core/params/__init__.py   |   1 +
 .../fate/components/core/params/_cv_param.py  |  14 ++
 .../ml/glm/hetero/coordinated_lr/guest.py     |   6 +-
 python/fate/ml/utils/_model_param.py          |   8 +-
 6 files changed, 184 insertions(+), 38 deletions(-)
 create mode 100644 examples/pipeline/test_lr_sid_cv.py
 create mode 100644 python/fate/components/core/params/_cv_param.py

diff --git a/examples/pipeline/test_lr_sid_cv.py b/examples/pipeline/test_lr_sid_cv.py
new file mode 100644
index 0000000000..434a7c0024
--- /dev/null
+++ b/examples/pipeline/test_lr_sid_cv.py
@@ -0,0 +1,52 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
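For reference, the heart of the toy component above is this indexer round-trip
(a sketch; `df_guest` and `df_host` stand for two FATE DataFrames that share
sample ids):

    # one side publishes its row order, the other realigns to it
    indexer = df_guest.get_indexer(target="sample_id")
    aligned = df_host.loc(indexer, preserve_order=True)  # df_host rows, df_guest order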
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, Intersection +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel + +pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") + +intersect_0 = Intersection("intersect_0", method="raw") +intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid_0713")) +intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace="experiment_sid_0713")) +lr_0 = CoordinatedLR("lr_0", + epochs=2, + batch_size=100, + optimizer={"method": "sgd", "optimizer_params": {"lr": 0.01}}, + init_param={"fit_intercept": True}, + cv_data=intersect_0.outputs["output_data"], + cv_param={"n_splits": 3}) + +"""lr_0.guest.component_setting(train_data=DataWarehouseChannel(name="breast_hetero_guest_sid", + namespace="experiment")) +lr_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name="breast_hetero_host_sid", + namespace="experiment"))""" + +evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + input_data=lr_0.outputs["train_output_data"]) + +# pipeline.add_task(feature_scale_0) +# pipeline.add_task(feature_scale_1) +pipeline.add_task(intersect_0) +pipeline.add_task(lr_0) +pipeline.add_task(evaluation_0) +# pipeline.add_task(hetero_feature_binning_0) +pipeline.compile() +print(pipeline.get_dag()) +pipeline.fit() diff --git a/python/fate/components/components/coordinated_lr.py b/python/fate/components/components/coordinated_lr.py index 9d1c448535..6194e7aa6b 100644 --- a/python/fate/components/components/coordinated_lr.py +++ b/python/fate/components/components/coordinated_lr.py @@ -17,7 +17,9 @@ import logging from fate.arch import Context +from fate.arch.dataframe import DataFrame from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params +from fate.ml.glm import CoordinatedLRModuleGuest, CoordinatedLRModuleHost, CoordinatedLRModuleArbiter logger = logging.getLogger(__name__) @@ -120,35 +122,114 @@ def predict( predict_host(ctx, input_model, test_data, test_output_data) -"""@coordinated_lr.cross_validation() +@coordinated_lr.cross_validation() def cross_validation( - ctx: Context, - role: Role, - data: cpn.dataframe_input(roles=[GUEST, HOST]), - num_fold: cpn.parameter(type=params.conint(ge=2), desc="num cross validation fold"), - learning_rate: cpn.parameter(type=params.learning_rate_param(), default=0.1, desc="learning rate"), - epochs: cpn.parameter(type=params.conint(gt=0), default=100, desc="max iteration num"), - batch_size: cpn.parameter( - type=params.conint(gt=0), default=100, desc="batch size, value less or equals to 0 means full batch" - ), + ctx: Context, + role: Role, + cv_data: cpn.dataframe_input(roles=[GUEST, HOST]), + learning_rate_scheduler: cpn.parameter( + type=params.lr_scheduler_param(), + default=params.LRSchedulerParam(method="linear", scheduler_params={"start_factor": 1.0}), + desc="learning rate scheduler, " + "select method 
from {'step', 'linear', 'constant'}" + "for list of configurable arguments, " + "refer to torch.optim.lr_scheduler", + ), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=-1), default=100, desc="batch size, " "value less or equals to 0 means full batch" + ), + optimizer: cpn.parameter( + type=params.optimizer_param(), + default=params.OptimizerParam( + method="sgd", penalty="l2", alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0} + ), + ), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), + default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}", + ), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="zeros", fit_intercept=True), + desc="Model param init setting.", + ), + threshold: cpn.parameter( + type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" + ), + cv_param: cpn.parameter(type=params.cv_param(), + default=params.CVParam(n_splits=5, shuffle=False, random_state=None), + desc="cross validation param"), + metrics: cpn.parameter(type=params.metrics_param(), default=["auc"]), + cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST]), ): - cv_ctx = ctx.on_cross_validations - data = ctx.reader(data).read_dataframe() - # TODO: split data - for i, fold_ctx in cv_ctx.ctxs_range(num_fold): + # temp code start + optimizer = optimizer.dict() + learning_rate_scheduler = learning_rate_scheduler.dict() + init_param = init_param.dict() + # temp code end + i = 0 + if role.is_arbiter: + for fold_ctx, _ in ctx.on_cross_validations.ctxs_zip(zip(range(cv_param.n_splits))): + logger.info(f"enter fold {i}") + module = CoordinatedLRModuleArbiter( + epochs=epochs, + early_stop=early_stop, + tol=tol, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + ) + module.fit(fold_ctx) + + from fate.arch.dataframe import KFold + kf = KFold(ctx, role=role, n_splits=cv_param.n_splits, shuffle=cv_param.shuffle, random_state=cv_param.random_state) + i = 0 + for fold_ctx, (train_data, validate_data) in ctx.on_cross_validations.ctxs_zip(kf.split(cv_data.read())): + logger.info(f"enter fold {i}") + logger.info(f"train_data schema: {train_data.schema}, columns: {train_data.schema.columns}") if role.is_guest: - from fate.ml.glm.coordinated_lr import CoordinatedLRModuleGuest - - module = CoordinatedLRModuleGuest(epochs=epochs, learning_rate=learning_rate, batch_size=batch_size) - train_data, validate_data = split_dataframe(data, num_fold, i) - module.fit(fold_ctx, train_data) - predicted = module.predict(fold_ctx, validate_data) - evaluation = evaluate(predicted) + module = CoordinatedLRModuleGuest( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + init_param=init_param, + threshold=threshold, + ) + module.fit(fold_ctx, train_data, validate_data) + sub_ctx = fold_ctx.sub_ctx("predict_train") + predict_score = module.predict(sub_ctx, train_data) + train_predict_result = transform_to_predict_result( + train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, + data_type="train" + ) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + predict_score = module.predict(sub_ctx, validate_data) + validate_predict_result = transform_to_predict_result( + 
validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, + data_type="predict" + ) + predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) + next(cv_output_datas).write(df=predict_result) + + # evaluation = evaluate(predicted) elif role.is_host: - ... - elif role.is_arbiter: - ... -""" + module = CoordinatedLRModuleHost( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + init_param=init_param, + ) + module.fit(fold_ctx, train_data, validate_data) + sub_ctx = fold_ctx.sub_ctx("predict_train") + module.predict(sub_ctx, train_data) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + module.predict(sub_ctx, validate_data) + i += 1 def train_guest( @@ -168,8 +249,6 @@ def train_guest( # optimizer = optimizer_factory(optimizer_param) logger.info(f"coordinated lr guest start train") - from fate.ml.glm import CoordinatedLRModuleGuest - sub_ctx = ctx.sub_ctx("train") module = CoordinatedLRModuleGuest( epochs=epochs, @@ -222,8 +301,6 @@ def train_host( init_param, ): logger.info(f"coordinated lr host start train") - from fate.ml.glm import CoordinatedLRModuleHost - sub_ctx = ctx.sub_ctx("train") module = CoordinatedLRModuleHost( epochs=epochs, @@ -249,8 +326,6 @@ def train_host( def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, output_model): logger.info(f"coordinated lr arbiter start train") - from fate.ml.glm import CoordinatedLRModuleArbiter - sub_ctx = ctx.sub_ctx("train") module = CoordinatedLRModuleArbiter( epochs=epochs, @@ -267,8 +342,6 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, lea def predict_guest(ctx, input_model, test_data, test_output_data): logger.info(f"coordinated lr guest start predict") - from fate.ml.glm import CoordinatedLRModuleGuest - sub_ctx = ctx.sub_ctx("predict") model = input_model.read() module = CoordinatedLRModuleGuest.from_model(model) @@ -284,8 +357,6 @@ def predict_guest(ctx, input_model, test_data, test_output_data): def predict_host(ctx, input_model, test_data, test_output_data): logger.info(f"coordinated lr host start predict") - from fate.ml.glm import CoordinatedLRModuleHost - sub_ctx = ctx.sub_ctx("predict") model = input_model.read() module = CoordinatedLRModuleHost.from_model(model) diff --git a/python/fate/components/core/params/__init__.py b/python/fate/components/core/params/__init__.py index d712f4553f..915fff3331 100644 --- a/python/fate/components/core/params/__init__.py +++ b/python/fate/components/core/params/__init__.py @@ -15,6 +15,7 @@ # from ._cipher import CipherParamType, PaillierCipherParam +from ._cv_param import cv_param, CVParam from ._fields import confloat, conint, jsonschema, parse, string_choice, Parameter from ._init_param import InitParam, init_param from ._init_param import InitParam, init_param diff --git a/python/fate/components/core/params/_cv_param.py b/python/fate/components/core/params/_cv_param.py new file mode 100644 index 0000000000..92b4c093b4 --- /dev/null +++ b/python/fate/components/core/params/_cv_param.py @@ -0,0 +1,14 @@ +import pydantic + +from ._fields import conint + + +class CVParam(pydantic.BaseModel): + n_splits: conint(gt=1) + shuffle: bool = False + random_state: int = None + + +def cv_param(): + namespace = {} + return type("CVParam", (CVParam,), namespace) diff --git a/python/fate/ml/glm/hetero/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py index 
b141360f70..99e5586581 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py @@ -175,8 +175,10 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None): loss = log2 - (1/N)*0.5*∑ywx + (1/N)*0.125*[∑(Wg*Xg)^2 + ∑(Wh*Xh)^2 + 2 * ∑(Wg*Xg * Wh*Xh)] """ coef_count = train_data.shape[1] + logger.info(f"coef count: {coef_count}") if self.init_param.get("fit_intercept"): - train_data["intercept"] = 1.0 + if "intercept" not in train_data.schema.columns: + train_data["intercept"] = 1.0 w = self.w if w is None: @@ -203,7 +205,7 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None): weight = batch_data.weight h = X.shape[0] # logger.info(f"h: {h}") - + logger.info(f"w: {w.detach()}, X shape: {X.shape}") Xw = torch.matmul(X, w.detach()) d = 0.25 * Xw - 0.5 * Y loss = 0.125 / h * torch.matmul(Xw.T, Xw) - 0.5 / h * torch.matmul(Xw.T, Y) diff --git a/python/fate/ml/utils/_model_param.py b/python/fate/ml/utils/_model_param.py index 2d2840a417..2f77f0912b 100644 --- a/python/fate/ml/utils/_model_param.py +++ b/python/fate/ml/utils/_model_param.py @@ -13,15 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + import torch +logger = logging.getLogger(__name__) + def initialize_param(coef_len, **kwargs): param_len = coef_len method = kwargs["method"] fit_intercept = kwargs["fit_intercept"] + logger.info(f"kwargs: {kwargs}") if fit_intercept: - param_len = coef_len + 1 + param_len = param_len + 1 + logger.info(f"intercept added: param len {param_len}") if method == 'zeros': return torch.zeros((param_len, 1), requires_grad=True) elif method == 'ones': From 39092445a039b9ab24398b78c03a6e3eee6a318e Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Thu, 13 Jul 2023 11:10:32 +0800 Subject: [PATCH 31/61] dataframe: fix loc in-place dm bug, add quantile interface Signed-off-by: mgqa34 --- python/fate/arch/dataframe/_dataframe.py | 8 +-- python/fate/arch/dataframe/ops/_quantile.py | 58 +++++++++++++++++++++ 2 files changed, 62 insertions(+), 4 deletions(-) create mode 100644 python/fate/arch/dataframe/ops/_quantile.py diff --git a/python/fate/arch/dataframe/_dataframe.py b/python/fate/arch/dataframe/_dataframe.py index 4f8ada5eb0..54649051e6 100644 --- a/python/fate/arch/dataframe/_dataframe.py +++ b/python/fate/arch/dataframe/_dataframe.py @@ -272,10 +272,10 @@ def describe(self, ddof=1, unbiased=False): def quantile( self, q, - axis=0, - method="quantile", + relative_error: float = 1e-4 ): - ... + from .ops._quantile import quantile + return quantile(self, q, relative_error) def __add__(self, other: Union[int, float, list, "np.ndarray", "DataFrame", "pd.Series"]) -> "DataFrame": return self.__arithmetic_operate(operator.add, other) @@ -519,7 +519,7 @@ def _merge_list(lhs, rhs): block_table = transform_list_block_to_frame_block(block_table, self._data_manager) partition_order_mappings = get_partition_order_mappings(block_table) - return DataFrame(self._ctx, block_table, partition_order_mappings, self._data_manager) + return DataFrame(self._ctx, block_table, partition_order_mappings, self._data_manager.duplicate()) def iloc(self, indexes): ... diff --git a/python/fate/arch/dataframe/ops/_quantile.py b/python/fate/arch/dataframe/ops/_quantile.py new file mode 100644 index 0000000000..e6063a7ff0 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_quantile.py @@ -0,0 +1,58 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
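How the new DataFrame.quantile is meant to be called (a sketch, assuming `df`
is a FATE DataFrame with numeric feature columns):

    # q accepts a single float or a list; relative_error trades accuracy for speed
    summary = df.quantile([0.25, 0.5, 0.75])  # default relative_error=1e-4
    median = df.quantile(0.5, relative_error=1e-3)
    # returns a pandas DataFrame indexed by q, one column per operable feature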
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import functools +import pandas as pd +from .._dataframe import DataFrame +from fate.arch.tensor.inside import GKSummary + + +def quantile(df: DataFrame, q, relative_error: float): + if isinstance(q, float): + q = [q] + elif not isinstance(q, list): + q = list(q) + + data_manager = df.data_manager + column_names = data_manager.infer_operable_field_names() + blocks_loc = [data_manager.loc_block(name) for name in column_names] + + def _mapper(blocks, columns_loc=None, error=None): + column_size = len(columns_loc) + gk_summary_obj_list = [GKSummary(error) for _ in range(column_size)] + + for idx, (bid, offset) in enumerate(columns_loc): + gk_summary_obj_list[idx] += blocks[bid][:, offset] + + return gk_summary_obj_list + + def _reducer(l_gk_summary_obj_list, r_gk_summary_obj_list): + rets = [] + for l_gk_summary_obj, r_gk_summary_obj in zip(l_gk_summary_obj_list, r_gk_summary_obj_list): + rets.append(l_gk_summary_obj + r_gk_summary_obj) + + return rets + + gk_summary_func = functools.partial(_mapper, columns_loc=blocks_loc, error=relative_error) + ret_gk_summary_obj_list = df.block_table.mapValues(gk_summary_func).reduce(_reducer) + + quantile_rets = dict() + for column_name, gk_summary_obj in zip(column_names, ret_gk_summary_obj_list): + query_ret = gk_summary_obj.queries(q) + quantile_rets[column_name] = query_ret + + quantile_df = pd.DataFrame(quantile_rets, index=q) + + return quantile_df From d6677ca2c5f099990a6b1ab92481208011c5ad70 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Thu, 13 Jul 2023 12:54:31 +0800 Subject: [PATCH 32/61] coordinated lr & linr add cv(#4659) set batch size default to None add linr cv examples Signed-off-by: Yu Wu --- examples/pipeline/test_linr_sid_cv.py | 38 ++++++ examples/pipeline/test_lr_sid_cv.py | 18 +-- .../fate/arch/dataframe/utils/_dataloader.py | 2 +- .../components/components/coordinated_linr.py | 129 +++++++++++++++--- .../components/components/coordinated_lr.py | 14 +- .../ml/glm/hetero/coordinated_linr/guest.py | 6 +- .../ml/glm/hetero/coordinated_lr/guest.py | 7 +- python/fate/ml/utils/_model_param.py | 6 - 8 files changed, 166 insertions(+), 54 deletions(-) create mode 100644 examples/pipeline/test_linr_sid_cv.py diff --git a/examples/pipeline/test_linr_sid_cv.py b/examples/pipeline/test_linr_sid_cv.py new file mode 100644 index 0000000000..a7e7d3a1e2 --- /dev/null +++ b/examples/pipeline/test_linr_sid_cv.py @@ -0,0 +1,38 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLinR, Intersection +from fate_client.pipeline.interface import DataWarehouseChannel + +pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") + +intersect_0 = Intersection("intersect_0", method="raw") +intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace="experiment_sid")) +intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace="experiment_sid")) +linr_0 = CoordinatedLinR("linr_0", + epochs=2, + batch_size=100, + optimizer={"method": "sgd", "optimizer_params": {"lr": 0.2}}, + init_param={"fit_intercept": True}, + cv_data=intersect_0.outputs["output_data"], + cv_param={"n_splits": 3}) + +pipeline.add_task(intersect_0) +pipeline.add_task(linr_0) +pipeline.compile() +print(pipeline.get_dag()) +pipeline.fit() diff --git a/examples/pipeline/test_lr_sid_cv.py b/examples/pipeline/test_lr_sid_cv.py index 434a7c0024..2f136a1d60 100644 --- a/examples/pipeline/test_lr_sid_cv.py +++ b/examples/pipeline/test_lr_sid_cv.py @@ -14,16 +14,15 @@ # limitations under the License. from fate_client.pipeline import FateFlowPipeline from fate_client.pipeline.components.fate import CoordinatedLR, Intersection -from fate_client.pipeline.components.fate import Evaluation from fate_client.pipeline.interface import DataWarehouseChannel pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") intersect_0 = Intersection("intersect_0", method="raw") intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace="experiment_sid_0713")) + namespace="experiment_sid")) intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace="experiment_sid_0713")) + namespace="experiment_sid")) lr_0 = CoordinatedLR("lr_0", epochs=2, batch_size=100, @@ -32,21 +31,8 @@ cv_data=intersect_0.outputs["output_data"], cv_param={"n_splits": 3}) -"""lr_0.guest.component_setting(train_data=DataWarehouseChannel(name="breast_hetero_guest_sid", - namespace="experiment")) -lr_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name="breast_hetero_host_sid", - namespace="experiment"))""" - -evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], - input_data=lr_0.outputs["train_output_data"]) - -# pipeline.add_task(feature_scale_0) -# pipeline.add_task(feature_scale_1) pipeline.add_task(intersect_0) pipeline.add_task(lr_0) -pipeline.add_task(evaluation_0) -# pipeline.add_task(hetero_feature_binning_0) pipeline.compile() print(pipeline.get_dag()) pipeline.fit() diff --git a/python/fate/arch/dataframe/utils/_dataloader.py b/python/fate/arch/dataframe/utils/_dataloader.py index 119e0b18bf..6fae9937f8 100644 --- a/python/fate/arch/dataframe/utils/_dataloader.py +++ b/python/fate/arch/dataframe/utils/_dataloader.py @@ -33,7 +33,7 @@ def __init__( self._dataset = dataset self._batch_size = batch_size if dataset: - if batch_size == -1: + if batch_size is None: self._batch_size = len(dataset) else: self._batch_size = min(batch_size, len(dataset)) diff --git a/python/fate/components/components/coordinated_linr.py b/python/fate/components/components/coordinated_linr.py index fbdc0ebcf0..6b71a548c0 100644 --- a/python/fate/components/components/coordinated_linr.py +++ 
b/python/fate/components/components/coordinated_linr.py @@ -19,6 +19,7 @@ from fate.arch import Context from fate.arch.dataframe import DataFrame from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params +from fate.ml.glm import CoordinatedLinRModuleArbiter, CoordinatedLinRModuleGuest, CoordinatedLinRModuleHost logger = logging.getLogger(__name__) @@ -43,9 +44,10 @@ def train( "refer to torch.optim.lr_scheduler"), epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), - batch_size: cpn.parameter(type=params.conint(ge=-1), default=100, - desc="batch size, " - "value less or equals to 0 means full batch"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), optimizer: cpn.parameter(type=params.optimizer_param(), default=params.OptimizerParam(method="sgd", penalty='l2', alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0})), @@ -92,11 +94,115 @@ def predict( predict_host(ctx, input_model, test_data, test_output_data) +@coordinated_linr.cross_validation() +def cross_validation( + ctx: Context, + role: Role, + cv_data: cpn.dataframe_input(roles=[GUEST, HOST]), + learning_rate_scheduler: cpn.parameter( + type=params.lr_scheduler_param(), + default=params.LRSchedulerParam(method="linear", scheduler_params={"start_factor": 1.0}), + desc="learning rate scheduler, " + "select method from {'step', 'linear', 'constant'}" + "for list of configurable arguments, " + "refer to torch.optim.lr_scheduler", + ), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + optimizer: cpn.parameter( + type=params.optimizer_param(), + default=params.OptimizerParam( + method="sgd", penalty="l2", alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0} + ), + ), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), + default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}", + ), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="zeros", fit_intercept=True), + desc="Model param init setting.", + ), + cv_param: cpn.parameter(type=params.cv_param(), + default=params.CVParam(n_splits=5, shuffle=False, random_state=None), + desc="cross validation param"), + metrics: cpn.parameter(type=params.metrics_param(), default=["mse"]), + cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST]), +): + # temp code start + optimizer = optimizer.dict() + learning_rate_scheduler = learning_rate_scheduler.dict() + init_param = init_param.dict() + # temp code end + if role.is_arbiter: + i = 0 + for fold_ctx, _ in ctx.on_cross_validations.ctxs_zip(zip(range(cv_param.n_splits))): + logger.info(f"enter fold {i}") + module = CoordinatedLinRModuleArbiter( + epochs=epochs, + early_stop=early_stop, + tol=tol, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + ) + module.fit(fold_ctx) + i += 1 + return + + from fate.arch.dataframe import KFold + kf = KFold(ctx, role=role, n_splits=cv_param.n_splits, shuffle=cv_param.shuffle, random_state=cv_param.random_state) + i = 0 + for fold_ctx, (train_data, validate_data) in 
ctx.on_cross_validations.ctxs_zip(kf.split(cv_data.read())): + logger.info(f"enter fold {i}") + if role.is_guest: + module = CoordinatedLinRModuleGuest( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + init_param=init_param + ) + module.fit(fold_ctx, train_data, validate_data) + sub_ctx = fold_ctx.sub_ctx("predict_train") + predict_score = module.predict(sub_ctx, train_data) + train_predict_result = transform_to_predict_result( + train_data, predict_score, data_type="train" + ) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + predict_score = module.predict(sub_ctx, validate_data) + validate_predict_result = transform_to_predict_result( + validate_data, predict_score, data_type="predict" + ) + predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) + next(cv_output_datas).write(df=predict_result) + + # evaluation = evaluate(predicted) + elif role.is_host: + module = CoordinatedLinRModuleHost( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + init_param=init_param + ) + module.fit(fold_ctx, train_data, validate_data) + sub_ctx = fold_ctx.sub_ctx("predict_train") + module.predict(sub_ctx, train_data) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + module.predict(sub_ctx, validate_data) + i += 1 + + def train_guest(ctx, train_data, validate_data, train_output_data, output_model, epochs, batch_size, optimizer_param, learning_rate_param, init_param): logger.info(f"coordinated linr guest start train") - from fate.ml.glm import CoordinatedLinRModuleGuest - # optimizer = optimizer_factory(optimizer_param) sub_ctx = ctx.sub_ctx("train") module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, @@ -125,10 +231,6 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model, def train_host(ctx, train_data, validate_data, train_output_data, output_model, epochs, batch_size, optimizer_param, learning_rate_param, init_param): logger.info(f"coordinated linr host start train") - - from fate.ml.glm import CoordinatedLinRModuleHost - # optimizer = optimizer_factory(optimizer_param) - sub_ctx = ctx.sub_ctx("train") module = CoordinatedLinRModuleHost(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, @@ -151,8 +253,6 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param, output_model): logger.info(f"coordinated linr arbiter start train") - from fate.ml.glm import CoordinatedLinRModuleArbiter - sub_ctx = ctx.sub_ctx("train") module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size, optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, @@ -165,8 +265,6 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, def predict_guest(ctx, input_model, test_data, test_output_data): - from fate.ml.glm import CoordinatedLinRModuleGuest - sub_ctx = ctx.sub_ctx("predict") model = input_model.read() @@ -178,8 +276,6 @@ def predict_guest(ctx, input_model, test_data, test_output_data): def predict_host(ctx, input_model, test_data, test_output_data): - from fate.ml.glm import CoordinatedLinRModuleHost - sub_ctx = ctx.sub_ctx("predict") model = input_model.read() module = CoordinatedLinRModuleHost.from_model(model) @@ -196,7 +292,4 @@ def 
transform_to_predict_result(test_data, predict_score, data_type="test"): v[0], json.dumps({"label": v[0]}), data_type], enable_type_align_checking=False) - # temp code start - df.rename(label_name="label") - # temp code end return df diff --git a/python/fate/components/components/coordinated_lr.py b/python/fate/components/components/coordinated_lr.py index 6194e7aa6b..2fab18751f 100644 --- a/python/fate/components/components/coordinated_lr.py +++ b/python/fate/components/components/coordinated_lr.py @@ -45,7 +45,8 @@ def train( ), epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), batch_size: cpn.parameter( - type=params.conint(ge=-1), default=100, desc="batch size, " "value less or equals to 0 means full batch" + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" ), optimizer: cpn.parameter( type=params.optimizer_param(), @@ -137,7 +138,8 @@ def cross_validation( ), epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), batch_size: cpn.parameter( - type=params.conint(ge=-1), default=100, desc="batch size, " "value less or equals to 0 means full batch" + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" ), optimizer: cpn.parameter( type=params.optimizer_param(), @@ -170,8 +172,8 @@ def cross_validation( learning_rate_scheduler = learning_rate_scheduler.dict() init_param = init_param.dict() # temp code end - i = 0 if role.is_arbiter: + i = 0 for fold_ctx, _ in ctx.on_cross_validations.ctxs_zip(zip(range(cv_param.n_splits))): logger.info(f"enter fold {i}") module = CoordinatedLRModuleArbiter( @@ -183,13 +185,14 @@ def cross_validation( learning_rate_param=learning_rate_scheduler, ) module.fit(fold_ctx) + i += 1 + return from fate.arch.dataframe import KFold kf = KFold(ctx, role=role, n_splits=cv_param.n_splits, shuffle=cv_param.shuffle, random_state=cv_param.random_state) i = 0 for fold_ctx, (train_data, validate_data) in ctx.on_cross_validations.ctxs_zip(kf.split(cv_data.read())): logger.info(f"enter fold {i}") - logger.info(f"train_data schema: {train_data.schema}, columns: {train_data.schema.columns}") if role.is_guest: module = CoordinatedLRModuleGuest( epochs=epochs, @@ -382,7 +385,4 @@ def transform_to_predict_result(test_data, predict_score, labels, threshold=0.5, lambda v: [int(v[0] > threshold), v[0], json.dumps({1: v[0], 0: 1 - v[0]}), data_type], enable_type_align_checking=False, ) - # temp code start - df.rename(label_name="label") - # temp code end return df diff --git a/python/fate/ml/glm/hetero/coordinated_linr/guest.py b/python/fate/ml/glm/hetero/coordinated_linr/guest.py index c875e5938d..1a50a45ce6 100644 --- a/python/fate/ml/glm/hetero/coordinated_linr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/guest.py @@ -108,8 +108,10 @@ def __init__( def fit_model(self, ctx, train_data, validate_data=None): coef_count = train_data.shape[1] + logger.debug(f"init param: {self.init_param}") if self.init_param.get("fit_intercept"): - train_data["intercept"] = 1 + logger.debug(f"add intercept to train data") + train_data["intercept"] = 1.0 w = self.w if self.w is None: w = initialize_param(coef_count, **self.init_param) @@ -174,6 +176,8 @@ def fit_model(self, ctx, train_data, validate_data=None): logger.debug(f"Finish training at {self.end_epoch}th epoch.") def predict(self, ctx, test_data): + if self.init_param.get("fit_intercept"): + 
test_data["intercept"] = 1.0 X = test_data.values.as_tensor() pred = torch.matmul(X, self.w) for h_pred in ctx.hosts.get("h_pred"): diff --git a/python/fate/ml/glm/hetero/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py index 99e5586581..63096077e9 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py @@ -175,10 +175,8 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None): loss = log2 - (1/N)*0.5*∑ywx + (1/N)*0.125*[∑(Wg*Xg)^2 + ∑(Wh*Xh)^2 + 2 * ∑(Wg*Xg * Wh*Xh)] """ coef_count = train_data.shape[1] - logger.info(f"coef count: {coef_count}") if self.init_param.get("fit_intercept"): - if "intercept" not in train_data.schema.columns: - train_data["intercept"] = 1.0 + train_data["intercept"] = 1.0 w = self.w if w is None: @@ -205,7 +203,6 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None): weight = batch_data.weight h = X.shape[0] # logger.info(f"h: {h}") - logger.info(f"w: {w.detach()}, X shape: {X.shape}") Xw = torch.matmul(X, w.detach()) d = 0.25 * Xw - 0.5 * Y loss = 0.125 / h * torch.matmul(Xw.T, Xw) - 0.5 / h * torch.matmul(Xw.T, Y) @@ -220,7 +217,7 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None): loss -= 0.5 / h * torch.matmul(Y.T, Xw_h) loss += 0.25 / h * torch.matmul(Xw.T, Xw_h) if weight: - logger.info(f"weight: {weight.tolist()}") + # logger.info(f"weight: {weight.tolist()}") d = d * weight batch_ctx.hosts.put(d=d) diff --git a/python/fate/ml/utils/_model_param.py b/python/fate/ml/utils/_model_param.py index 2f77f0912b..e57f9cc29a 100644 --- a/python/fate/ml/utils/_model_param.py +++ b/python/fate/ml/utils/_model_param.py @@ -13,21 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import logging - import torch -logger = logging.getLogger(__name__) - def initialize_param(coef_len, **kwargs): param_len = coef_len method = kwargs["method"] fit_intercept = kwargs["fit_intercept"] - logger.info(f"kwargs: {kwargs}") if fit_intercept: param_len = param_len + 1 - logger.info(f"intercept added: param len {param_len}") if method == 'zeros': return torch.zeros((param_len, 1), requires_grad=True) elif method == 'ones': From 94d798a9ff84e33a5b2ac45d8dea36b48fe6ca95 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Thu, 13 Jul 2023 16:36:37 +0800 Subject: [PATCH 33/61] fix optimizer(#4659) fix dataloader for default None batch size Signed-off-by: Yu Wu --- python/fate/arch/dataframe/utils/_dataloader.py | 2 +- python/fate/ml/utils/_optimizer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/fate/arch/dataframe/utils/_dataloader.py b/python/fate/arch/dataframe/utils/_dataloader.py index 6fae9937f8..2448b2f666 100644 --- a/python/fate/arch/dataframe/utils/_dataloader.py +++ b/python/fate/arch/dataframe/utils/_dataloader.py @@ -83,7 +83,7 @@ def __init__(self, dataset, ctx, mode, role, batch_size, shuffle, random_seed, n self._mode = mode self._role = role self._batch_size = batch_size - if self._batch_size < 0 and self._role != "arbiter": + if self._batch_size is None and self._role != "arbiter": self._batch_size = len(self._dataset) self._shuffle = shuffle self._random_seed = random_seed diff --git a/python/fate/ml/utils/_optimizer.py b/python/fate/ml/utils/_optimizer.py index c586443523..64a70508dc 100644 --- a/python/fate/ml/utils/_optimizer.py +++ b/python/fate/ml/utils/_optimizer.py @@ -94,7 +94,7 @@ def get_delta_gradients(self): # logger.info(f"gradient: {self.model_parameter.grad}, prev model parameter: {self.prev_model_parameter}," # f"delta grad: {self.model_parameter - self.prev_model_parameter}") if self.prev_model_parameter is not None: - return self.model_parameter.data - self.prev_model_parameter + return self.prev_model_parameter - self.model_parameter.data else: raise ValueError(f"No optimization history found, please check.") From 81f07ef689d845a659eaac86654a4c5351491ed9 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Thu, 13 Jul 2023 16:54:55 +0800 Subject: [PATCH 34/61] make output cv data optional(#4659) Signed-off-by: Yu Wu --- .../components/components/coordinated_linr.py | 37 +++++++++-------- .../components/components/coordinated_lr.py | 41 ++++++++++--------- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/python/fate/components/components/coordinated_linr.py b/python/fate/components/components/coordinated_linr.py index 6b71a548c0..9a83207d95 100644 --- a/python/fate/components/components/coordinated_linr.py +++ b/python/fate/components/components/coordinated_linr.py @@ -133,7 +133,8 @@ def cross_validation( default=params.CVParam(n_splits=5, shuffle=False, random_state=None), desc="cross validation param"), metrics: cpn.parameter(type=params.metrics_param(), default=["mse"]), - cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST]), + output_cv_data: cpn.parameter(type=bool, default=True, desc="whether output prediction result per cv fold"), + cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True), ): # temp code start optimizer = optimizer.dict() @@ -170,18 +171,19 @@ def cross_validation( init_param=init_param ) module.fit(fold_ctx, train_data, validate_data) - sub_ctx = fold_ctx.sub_ctx("predict_train") - predict_score = module.predict(sub_ctx, train_data) - train_predict_result = 
transform_to_predict_result( - train_data, predict_score, data_type="train" - ) - sub_ctx = fold_ctx.sub_ctx("predict_validate") - predict_score = module.predict(sub_ctx, validate_data) - validate_predict_result = transform_to_predict_result( - validate_data, predict_score, data_type="predict" - ) - predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) - next(cv_output_datas).write(df=predict_result) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + predict_score = module.predict(sub_ctx, train_data) + train_predict_result = transform_to_predict_result( + train_data, predict_score, data_type="train" + ) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + predict_score = module.predict(sub_ctx, validate_data) + validate_predict_result = transform_to_predict_result( + validate_data, predict_score, data_type="predict" + ) + predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) + next(cv_output_datas).write(df=predict_result) # evaluation = evaluate(predicted) elif role.is_host: @@ -193,10 +195,11 @@ def cross_validation( init_param=init_param ) module.fit(fold_ctx, train_data, validate_data) - sub_ctx = fold_ctx.sub_ctx("predict_train") - module.predict(sub_ctx, train_data) - sub_ctx = fold_ctx.sub_ctx("predict_validate") - module.predict(sub_ctx, validate_data) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + module.predict(sub_ctx, train_data) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + module.predict(sub_ctx, validate_data) i += 1 diff --git a/python/fate/components/components/coordinated_lr.py b/python/fate/components/components/coordinated_lr.py index 2fab18751f..02e0b4c03c 100644 --- a/python/fate/components/components/coordinated_lr.py +++ b/python/fate/components/components/coordinated_lr.py @@ -165,7 +165,8 @@ def cross_validation( default=params.CVParam(n_splits=5, shuffle=False, random_state=None), desc="cross validation param"), metrics: cpn.parameter(type=params.metrics_param(), default=["auc"]), - cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST]), + output_cv_data: cpn.parameter(type=bool, default=True, desc="whether output prediction result per cv fold"), + cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True), ): # temp code start optimizer = optimizer.dict() @@ -203,20 +204,21 @@ def cross_validation( threshold=threshold, ) module.fit(fold_ctx, train_data, validate_data) - sub_ctx = fold_ctx.sub_ctx("predict_train") - predict_score = module.predict(sub_ctx, train_data) - train_predict_result = transform_to_predict_result( - train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, - data_type="train" - ) - sub_ctx = fold_ctx.sub_ctx("predict_validate") - predict_score = module.predict(sub_ctx, validate_data) - validate_predict_result = transform_to_predict_result( - validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, - data_type="predict" - ) - predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) - next(cv_output_datas).write(df=predict_result) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + predict_score = module.predict(sub_ctx, train_data) + train_predict_result = transform_to_predict_result( + train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, + data_type="train" + ) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + predict_score = module.predict(sub_ctx, validate_data) + 
validate_predict_result = transform_to_predict_result( + validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, + data_type="predict" + ) + predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) + next(cv_output_datas).write(df=predict_result) # evaluation = evaluate(predicted) elif role.is_host: @@ -228,10 +230,11 @@ def cross_validation( init_param=init_param, ) module.fit(fold_ctx, train_data, validate_data) - sub_ctx = fold_ctx.sub_ctx("predict_train") - module.predict(sub_ctx, train_data) - sub_ctx = fold_ctx.sub_ctx("predict_validate") - module.predict(sub_ctx, validate_data) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + module.predict(sub_ctx, train_data) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + module.predict(sub_ctx, validate_data) i += 1 From ff4d4622c352b1faf2971beeaf589821b5a87836 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Thu, 13 Jul 2023 18:15:30 +0800 Subject: [PATCH 35/61] dataframe: support qcut and bucketize Signed-off-by: mgqa34 --- python/fate/arch/dataframe/_dataframe.py | 8 +++ python/fate/arch/dataframe/ops/_encoder.py | 76 ++++++++++++++++++++- python/fate/arch/dataframe/ops/_quantile.py | 16 +++++ 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/python/fate/arch/dataframe/_dataframe.py b/python/fate/arch/dataframe/_dataframe.py index 54649051e6..5b27cada4b 100644 --- a/python/fate/arch/dataframe/_dataframe.py +++ b/python/fate/arch/dataframe/_dataframe.py @@ -277,6 +277,14 @@ def quantile( from .ops._quantile import quantile return quantile(self, q, relative_error) + def qcut(self, q: int): + from .ops._quantile import qcut + return qcut(self, q) + + def bucketize(self, boundaries: Union[dict, pd.DataFrame]) -> "DataFrame": + from .ops._encoder import bucketize + return bucketize(self, boundaries) + def __add__(self, other: Union[int, float, list, "np.ndarray", "DataFrame", "pd.Series"]) -> "DataFrame": return self.__arithmetic_operate(operator.add, other) diff --git a/python/fate/arch/dataframe/ops/_encoder.py b/python/fate/arch/dataframe/ops/_encoder.py index f1d6eb7d84..ebb0a70863 100644 --- a/python/fate/arch/dataframe/ops/_encoder.py +++ b/python/fate/arch/dataframe/ops/_encoder.py @@ -13,10 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
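# --- Editorial aside (illustrative, not part of the patch) ----------------------
# DataFrame.bucketize above delegates, per column, to torch.bucketize, with the
# last boundary forced to +inf so the column maximum still lands in the final
# bucket (see the mapper added to _encoder.py below). Note that the companion
# qcut added in _quantile.py derives equal-width split points from min/max
# rather than true quantiles. A self-contained sketch of the per-column kernel:
import torch

values = torch.tensor([[0.10], [0.35], [0.90]])
boundaries = torch.tensor([0.25, 0.50, float("inf")])  # last edge set to inf

bucket_idx = torch.bucketize(values, boundaries)
# 0.10 <= 0.25 -> 0; 0.25 < 0.35 <= 0.50 -> 1; 0.50 < 0.90 <= inf -> 2
assert bucket_idx.flatten().tolist() == [0, 1, 2]
# --------------------------------------------------------------------------------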
# +import functools + +import pandas as pd import numpy as np +import torch from sklearn.preprocessing import OneHotEncoder +from typing import Union from .._dataframe import DataFrame -from ..manager import BlockType +from ..manager import BlockType, DataManager + + +BUCKETIZE_RESULT_TYPE = "int32" def get_dummies(df: "DataFrame", dtype="int32"): @@ -98,3 +106,69 @@ def _encode(blocks): return ret_blocks return block_table.mapValues(_encode) + + +def bucketize(df: DataFrame, boundaries: Union[pd.DataFrame, dict]): + if isinstance(boundaries, pd.DataFrame): + boundaries = dict([(_name, boundaries[_name].tolist()) for _name in boundaries]) + elif not isinstance(boundaries, dict): + raise ValueError("boundaries should be pd.DataFrame or dict") + + data_manager = df.data_manager.duplicate() + field_names = list(filter(lambda field_name: field_name in boundaries, data_manager.infer_operable_field_names())) + blocks_loc = data_manager.loc_block(field_names) + + _boundaries_list = [] + for name, (_bid, _) in zip(field_names, blocks_loc): + if BlockType.is_tensor(data_manager.blocks[_bid].block_type): + _boundary = torch.tensor(boundaries[name]) + _boundary[-1] = torch.inf + else: + _boundary = np.array(boundaries[name]) + _boundary[-1] = np.inf + + _boundaries_list.append((_bid, _, _boundary)) + + narrow_blocks, dst_blocks = data_manager.split_columns(field_names, BlockType.get_block_type(BUCKETIZE_RESULT_TYPE)) + + def _mapper(blocks, boundaries_list: list = None, narrow_loc: list = None, + dst_bids: list = None, dm: DataManager = None): + ret_blocks = [] + for block in blocks: + if isinstance(block, torch.Tensor): + ret_blocks.append(block.clone()) + elif isinstance(block, np.ndarray): + ret_blocks.append(block.copy()) + else: + ret_blocks.append(block) + + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + for dst_bid, (src_bid, src_offset, boundary) in zip(dst_bids, boundaries_list): + if isinstance(blocks[src_bid], torch.Tensor): + ret = torch.bucketize(blocks[src_bid][:, [src_offset]], boundary, out_int32=False) + else: + ret = torch.bucketize(blocks[src_bid][:, [src_offset]], boundary) + + ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block(ret) + + return ret_blocks + + bucketize_mapper = functools.partial(_mapper, + boundaries_list=_boundaries_list, + narrow_loc=narrow_blocks, + dst_bids=dst_blocks, + dm=data_manager) + + block_table = df.block_table.mapValues(bucketize_mapper) + + return DataFrame( + df._ctx, + block_table, + partition_order_mappings=df.partition_order_mappings, + data_manager=data_manager + ) diff --git a/python/fate/arch/dataframe/ops/_quantile.py b/python/fate/arch/dataframe/ops/_quantile.py index e6063a7ff0..6226284e3d 100644 --- a/python/fate/arch/dataframe/ops/_quantile.py +++ b/python/fate/arch/dataframe/ops/_quantile.py @@ -56,3 +56,19 @@ def _reducer(l_gk_summary_obj_list, r_gk_summary_obj_list): quantile_df = pd.DataFrame(quantile_rets, index=q) return quantile_df + + +def qcut(df: DataFrame, q: int): + assert isinstance(q, int), f"to use qcut, {q} should be positive integer" + max_ret = df.max() + min_ret = df.min() + + dist = (max_ret - min_ret) / q + + cut_ret = [] + for i in range(1, q): + cut_ret.append(min_ret + i * dist) + + cut_ret.append(max_ret) + + return pd.DataFrame(cut_ret, index=range(1, q + 1, 1)) From 0bf468db67019f1347fd7866623b40ee4dcd5312 Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 13 Jul 2023 18:56:42 +0800 Subject: [PATCH 
36/61] =?UTF-8?q?1.=20Update=20Evaluation:=20support=20col?=
 =?UTF-8?q?=20specify=20and=20multi=20default=20col=20update=202.=20Update?=
 =?UTF-8?q?=20the=20IO=20design=20of=20Homo-NN=203.=20Update=20Homo-LR?=
 =?UTF-8?q?=EF=BC=9Aall=20features=20supported=204.=20update=20predict=20t?=
 =?UTF-8?q?ools:=20from=20component=20level=20to=20ml=20level?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: cwj
---
 fate_client                                   |   2 +-
 .../fate/components/components/evaluation.py  |  55 ++-
 python/fate/components/components/homo_lr.py  |  33 +-
 python/fate/components/components/homo_nn.py  | 134 +++----
 .../components/components/nn/nn_runner.py     | 344 ++++++------------
 .../components/nn/runner/default_runner.py    | 269 +++++++-------
 .../components/nn/test/test_default_runner.py |  60 +++
 .../components/utils/predict_format.py        |  91 -----
 .../fate/components/components/utils/tools.py |   8 +-
 python/fate/ml/evaluation/metric_base.py      |   5 +-
 python/fate/ml/glm/__init__.py                |   2 +
 .../ml/glm/{homo_lr => homo/lr}/client.py     | 208 +++++------
 .../ml/glm/{homo_lr => homo/lr}/server.py     |   0
 .../{homo_lr => homo/lr}/test/local_test.py   |  19 +-
 python/fate/ml/nn/dataset/base.py             | 174 +--------
 python/fate/ml/nn/dataset/table.py            | 216 ++++++-----
 python/fate/ml/nn/trainer/trainer_base.py     |   2 -
 python/fate/ml/utils/predict_format.py        |  84 -----
 python/fate/ml/utils/predict_tools.py         | 112 ++++++
 19 files changed, 780 insertions(+), 1038 deletions(-)
 create mode 100644 python/fate/components/components/nn/test/test_default_runner.py
 delete mode 100644 python/fate/components/components/utils/predict_format.py
 rename python/fate/ml/glm/{homo_lr => homo/lr}/client.py (68%)
 rename python/fate/ml/glm/{homo_lr => homo/lr}/server.py (100%)
 rename python/fate/ml/glm/{homo_lr => homo/lr}/test/local_test.py (65%)
 delete mode 100644 python/fate/ml/utils/predict_format.py
 create mode 100644 python/fate/ml/utils/predict_tools.py

diff --git a/fate_client b/fate_client
index 7fcb28c933..ab81005987 160000
--- a/fate_client
+++ b/fate_client
@@ -1 +1 @@
-Subproject commit 7fcb28c93331cd50e285e4e6c1bcd7ec8f8b896e
+Subproject commit ab81005987e43a6771b6316007e5e81c20480669
diff --git a/python/fate/components/components/evaluation.py b/python/fate/components/components/evaluation.py
index 3c560276e2..27fc49ea98 100644
--- a/python/fate/components/components/evaluation.py
+++ b/python/fate/components/components/evaluation.py
@@ -26,7 +26,8 @@
     get_regression_metrics,
     get_specified_metrics,
 )
-from fate.components.components.utils.predict_format import PREDICT_SCORE, LABEL
+from fate.ml.utils.predict_tools import PREDICT_SCORE, PREDICT_RESULT, LABEL
+from fate.components.components.utils.consts import BINARY, REGRESSION, MULTI
 
 logger = logging.getLogger(__name__)
 
@@ -44,49 +45,65 @@ def evaluation(
     ctx: Context,
     role: Role,
     input_data: cpn.dataframe_inputs(roles=[GUEST, HOST]),
-    default_eval_metrics: cpn.parameter(
+    default_eval_setting: cpn.parameter(
         type=string_choice(choice=["binary", "multi", "regression"]), default="binary", optional=True
     ),
-    metrics: cpn.parameter(type=list, default=None, optional=True)
+    metrics: cpn.parameter(type=list, default=None, optional=True),
+    predict_column_name: cpn.parameter(type=str, default=None, optional=True,
+                                       desc="predict data column name; if None (default), will use \
+                                       'predict_score' for the binary and regression default settings, \
+                                       and 'predict_result' for the multi-classification default setting"),
+    label_column_name: cpn.parameter(type=str, default=None, optional=True, desc="label data column name; if None (default), \
+                                     will use 'label' in the input dataframe")
 ):
     if role.is_arbiter:
         return
     else:
+
         if metrics is not None:
             metrics_ensemble = get_specified_metrics(metrics)
+            predict_col = predict_column_name if predict_column_name is not None else PREDICT_SCORE
+            label_col = label_column_name if label_column_name is not None else LABEL
         else:
-            if default_eval_metrics == "binary":
-                metrics_ensemble = get_binary_metrics()
-            elif default_eval_metrics == "multi":
+            if default_eval_setting == MULTI:
                 metrics_ensemble = get_multi_metrics()
-            elif default_eval_metrics == "regression":
-                metrics_ensemble = get_regression_metrics()
+                predict_col = predict_column_name if predict_column_name is not None else PREDICT_RESULT
+                label_col = label_column_name if label_column_name is not None else LABEL
+            else:
+                if default_eval_setting == BINARY:
+                    metrics_ensemble = get_binary_metrics()
+                elif default_eval_setting == REGRESSION:
+                    metrics_ensemble = get_regression_metrics()
+                else:
+                    raise ValueError("default_eval_setting should be one of binary, multi, regression, got {}".format(default_eval_setting))
+                predict_col = predict_column_name if predict_column_name is not None else PREDICT_SCORE
+                label_col = label_column_name if label_column_name is not None else LABEL
 
         df_list = [_input.read() for _input in input_data]
-        component_name = [_input.artifact.metadata.source.component for _input in input_data]
-        component_rs = {}
-        for name, df in zip(component_name, df_list):
-            rs_dict = evaluate(df, metrics_ensemble)
-            component_rs[name] = rs_dict
+        task_names = [_input.artifact.metadata.source.task_name for _input in input_data]
+        eval_rs = {}
+        logger.info('component names are {}'.format(task_names))
+        for name, df in zip(task_names, df_list):
+            rs_ = evaluate(df, metrics_ensemble, predict_col, label_col)
+            eval_rs[name] = rs_
 
-        ctx.metrics.log_metrics(rs_dict, name='evaluation', type='evaluation')
-        logger.info("eval result: {}".format(rs_dict))
+        ctx.metrics.log_metrics(eval_rs, name='evaluation', type='evaluation')
+        logger.info("eval result: {}".format(eval_rs))
 
-def evaluate(input_data, metrics):
+def evaluate(input_data, metrics, predict_col, label_col):
     data = input_data.as_pd_df()
     split_dict = split_dataframe_by_type(data)
     rs_dict = {}
     logger.info('eval dataframe is {}'.format(data))
-
     for name, df in split_dict.items():
         logger.info('eval dataframe is \n\n{}'.format(df))
-        y_true = df[LABEL]
+        y_true = df[label_col]
         # in case is multi result, use tolist
-        y_pred = df[PREDICT_SCORE]
+        y_pred = df[predict_col]
         rs = metrics(predict=y_pred, label=y_true)
         rs_dict[name] = rs
 
diff --git a/python/fate/components/components/homo_lr.py b/python/fate/components/components/homo_lr.py
index 618b74c1d0..fa988d5100 100644
--- a/python/fate/components/components/homo_lr.py
+++ b/python/fate/components/components/homo_lr.py
@@ -14,12 +14,13 @@
 #  limitations under the License.
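# --- Editorial aside (illustrative, not part of the patch) ----------------------
# The evaluation changes above resolve the score and label columns as follows;
# this hypothetical helper condenses the branching for readability (column-name
# constants assumed to be 'predict_score', 'predict_result' and 'label'):
def _resolve_eval_columns(metrics, default_eval_setting,
                          predict_column_name=None, label_column_name=None):
    # explicit metrics, or the binary/regression defaults, read 'predict_score';
    # only the multi-classification default reads 'predict_result'
    if metrics is None and default_eval_setting == "multi":
        predict_col = predict_column_name or "predict_result"
    else:
        predict_col = predict_column_name or "predict_score"
    return predict_col, (label_column_name or "label")
# --------------------------------------------------------------------------------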
 import logging
 from fate.arch import Context
-from fate.ml.glm.homo_lr.client import HomoLRClient
-from fate.ml.glm.homo_lr.server import HomoLRServer
+from fate.ml.glm.homo.lr.client import HomoLRClient
+from fate.ml.glm.homo.lr.server import HomoLRServer
 from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params
 from fate.components.components.utils import consts
 from fate.ml.utils.model_io import ModelIO
-
+from fate.components.components.utils.tools import add_dataset_type
+from fate.arch.dataframe import DataFrame
 
 logger = logging.getLogger(__name__)
 
@@ -37,8 +38,8 @@ def train(
     train_data: cpn.dataframe_input(roles=[GUEST, HOST]),
     validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True),
     learning_rate_scheduler: cpn.parameter(type=params.lr_scheduler_param(),
-                                           default=params.LRSchedulerParam(method="linear",
-                                                                           scheduler_params={"start_factor": 1.0}),
+                                           default=params.LRSchedulerParam(method="constant",
+                                                                           scheduler_params={"factor": 1.0}),
                                            desc="learning rate scheduler, "
                                                 "select method from {'step', 'linear', 'constant'}; "
                                                 "for list of configurable arguments, "
                                                 "refer to torch.optim.lr_scheduler"),
@@ -52,12 +53,12 @@ def train(
                              default=params.OptimizerParam(method="sgd", penalty='l2', alpha=1.0,
                                                            optimizer_params={"lr": 1e-2, "weight_decay": 0})),
     init_param: cpn.parameter(type=params.init_param(),
-                              default=params.InitParam(method='zeros', fit_intercept=True),
+                              default=params.InitParam(method='random', fit_intercept=True),
                               desc="Model param init setting."),
     threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5,
                              desc="predict threshold for binary data"),
     ovr: cpn.parameter(type=bool, default=False,
-                       desc="predict threshold for binary data"),
+                       desc="enable ovr for multi-classification"),
     label_num: cpn.parameter(type=params.conint(ge=2), default=None),
     train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
     train_input_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True),
@@ -69,19 +70,26 @@ def train(
     if role.is_guest or role.is_host:  # is client
 
         logger.info('homo lr component: client start training')
-        logger.info('optim param {} init param {}'.format(optimizer.dict(), init_param.dict()))
+        logger.info('optim param {} \n init param {} \n learning rate param {}'.format(optimizer.dict(), init_param.dict(), learning_rate_scheduler.dict()))
         client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer.dict(), init_param=init_param.dict(),
-                              learning_rate_scheduler=0.01, threshold=threshold, ovr=ovr, label_num=label_num)
+                              learning_rate_scheduler=learning_rate_scheduler.dict(), threshold=threshold, ovr=ovr, label_num=label_num)
+
+        if train_input_model is not None:
+            model_input = train_input_model.read()
+            client.from_model(model_input)
+            logger.info('model input loaded')
         train_df = train_data.read()
         validate_df = validate_data.read() if validate_data else None
         client.fit(sub_ctx, train_df, validate_df)
         model_dict = client.get_model().dict()
 
         train_rs = client.predict(sub_ctx, train_df)
+        train_rs = add_dataset_type(train_rs, consts.TRAIN_SET)
         if validate_df:
             validate_rs = client.predict(sub_ctx, validate_df)
-            ret_df = train_rs.vstack(validate_rs)
+            validate_rs = add_dataset_type(validate_rs, consts.VALIDATE_SET)
+            ret_df = DataFrame.vstack([train_rs, validate_rs])
         else:
             ret_df = train_rs
 
@@ -113,9 +121,10 @@ def predict(
         client = HomoLRClient(batch_size=batch_size, threshold=threshold)
         model_input = predict_input_model.read()
         model_data = ModelIO.from_dict(model_input)
-        logger.info('model input is {}'.format(model_input))
+
client.from_model(model_data) pred_rs = client.predict(ctx, test_data.read()) - + pred_rs = add_dataset_type(pred_rs, consts.TEST_SET) + test_output_data.write(pred_rs) elif role.is_arbiter: # is server logger.info("arbiter skip predict") diff --git a/python/fate/components/components/homo_nn.py b/python/fate/components/components/homo_nn.py index 6da14833b6..4b837b4f93 100644 --- a/python/fate/components/components/homo_nn.py +++ b/python/fate/components/components/homo_nn.py @@ -14,16 +14,14 @@ # limitations under the License. import logging import os - -import pandas as pd from fate.arch import Context -from fate.arch.dataframe import PandasReader from fate.components.components.nn.loader import Loader -from fate.components.components.nn.nn_runner import NNInput, NNOutput, NNRunner +from fate.components.components.nn.nn_runner import NNRunner from fate.components.components.nn.runner.default_runner import DefaultRunner from fate.components.components.utils import consts -from fate.components.components.utils.predict_format import LABEL from fate.components.core import ARBITER, GUEST, HOST, Role, cpn +from fate.arch.dataframe import DataFrame +from fate.components.components.utils.tools import add_dataset_type logger = logging.getLogger(__name__) @@ -60,26 +58,22 @@ def prepare_context_and_role(runner, ctx, role, sub_ctx_name): sub_ctx = ctx.sub_ctx(sub_ctx_name) runner.set_context(sub_ctx) runner.set_role(role) - return sub_ctx -def get_input_data(stage, cpn_input_data, fate_save_path='./', saved_model_path=None, - input_type='df',): +def get_input_data(stage, cpn_input_data): + if stage == 'train': train_data, validate_data = cpn_input_data - if input_type == "df": - train_data = train_data.read() - if validate_data is not None: - validate_data = validate_data.read() + train_data = train_data.read() + if validate_data is not None: + validate_data = validate_data.read() - return NNInput(train_data=train_data, validate_data=validate_data, - fate_save_path=fate_save_path, saved_model_path=saved_model_path) + return train_data, validate_data elif stage == 'predict': test_data = cpn_input_data test_data = test_data.read() - return NNInput(test_data=test_data, - fate_save_path=fate_save_path, saved_model_path=saved_model_path) + return test_data else: raise ValueError(f"Unknown stage {stage}") @@ -92,63 +86,19 @@ def get_model_output_conf(runner_module, runner_class, runner_conf, source, - model_output_path ): return { "runner_module": runner_module, "runner_class": runner_class, "runner_conf": runner_conf, "source": source, - "saved_model_path": model_output_path, } -def write_output_df(ctx, result_df: pd.DataFrame, output_data_cls, match_id_name, sample_id_name): - - reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, label_name=LABEL, dtype="object") - data = reader.to_frame(ctx, result_df) - output_data_cls.write(data) - - -def handle_nn_output(ctx, nn_output: NNOutput, output_class, stage): - - if nn_output is None: - logger.warning("runner output is None in stage:{}, skip processing".format(stage)) - - elif isinstance(nn_output, NNOutput): - if stage == consts.TRAIN: - - if nn_output.train_result is None and nn_output.validate_result is None: - raise ValueError( - "train result and validate result are both None in the NNOutput: {}".format(nn_output) - ) - - df_train, df_val = nn_output.train_result, nn_output.validate_result - - match_id_name, sample_id_name = nn_output.match_id_name, nn_output.sample_id_name - if df_train is not None and df_val is not 
None: - df_train_val = pd.concat([df_train, df_val], axis=0) - df_train_val.match_id_name = df_train.match_id_name - write_output_df(ctx, df_train_val, output_class, match_id_name, sample_id_name) - elif df_train is not None: - write_output_df(ctx, df_train, output_class, match_id_name, sample_id_name) - elif df_val is not None: - write_output_df(ctx, df_val, output_class, match_id_name, sample_id_name) - if stage == consts.PREDICT: - if nn_output.test_result is None: - raise ValueError("test result not found in the NNOutput: {}".format(nn_output)) - write_output_df( - ctx, nn_output.test_result, output_class, nn_output.match_id_name, nn_output.sample_id_name - ) - else: - logger.warning("train output is not NNOutput, but {}, fail to output dataframe".format(type(nn_output))) - def prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, source): logger.info("loaded model_conf is: {}".format(model_conf)) - if "saved_model_path" in model_conf: - saved_model_path = model_conf["saved_model_path"] if "source" in model_conf: if source is None: source = model_conf["source"] @@ -165,7 +115,7 @@ def prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, so "use the new runner_conf, runner_class and runner module to train the model,\ saved module & class: {} {}, new module & class: {} {}".format(runner_module_, runner_class_, runner_module, runner_class)) - return runner_conf, source, runner_class, runner_module, saved_model_path + return runner_conf, source, runner_class, runner_module @cpn.component(roles=[GUEST, HOST, ARBITER]) @@ -183,32 +133,50 @@ def train( runner_class: cpn.parameter(type=str, default="DefaultRunner", desc="class name of your runner class"), runner_conf: cpn.parameter(type=dict, default={}, desc="the parameter dict of the NN runner class"), source: cpn.parameter(type=str, default=None, desc="path to your runner script folder"), - train_data_output: cpn.dataframe_output(roles=[GUEST, HOST]), - train_model_output: cpn.model_directory_output(roles=[GUEST, HOST]), + train_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True), + train_model_output: cpn.model_directory_output(roles=[GUEST, HOST], optional=True), train_model_input: cpn.model_directory_input(roles=[GUEST, HOST], optional=True), ): runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source) - sub_ctx = prepare_context_and_role(runner, ctx, role, consts.TRAIN) + prepare_context_and_role(runner, ctx, role, consts.TRAIN) if role.is_guest or role.is_host: # is client - saved_model_path=None if train_model_input is not None: model_conf = train_model_input.get_metadata() - runner_conf, source, runner_class, runner_module, saved_model_path = prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, source) + runner_conf, source, runner_class, runner_module = prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, source) + saved_model_path = str(train_model_input.get_directory()) + else: + saved_model_path = None + + output_dir = str(train_model_output.get_directory()) + train_data_, validate_data_ = get_input_data(consts.TRAIN, [train_data, validate_data]) + runner.train(train_data_, validate_data_, output_dir, saved_model_path) + + logger.info('Predicting Train & Validate Data') + train_pred = runner.predict(train_data_, saved_model_path) + if train_pred is not None: + assert isinstance(train_pred, DataFrame), "train predict result should be a DataFrame" + add_dataset_type(train_pred, consts.TRAIN_SET) + + 
if validate_data_ is not None: + validate_pred = runner.predict(validate_data_) + assert isinstance(validate_pred, DataFrame), "validate predict result should be a DataFrame" + add_dataset_type(validate_pred, consts.VALIDATE_SET) + output_df = DataFrame.vstack([train_pred, validate_pred]) + else: + output_df = train_pred + logger.info('write result dataframe') + train_data_output.write(output_df) + else: + logger.warning("train_pred is None, It seems that the runner is not able to predict. Failed to output data") - output_path = train_model_output.get_directory() - input_data = get_input_data(consts.TRAIN, [train_data, validate_data], output_path, saved_model_path) - ret: NNOutput = runner.train(input_data=input_data) - logger.info("train result: {}".format(ret)) - handle_nn_output(sub_ctx, ret, train_data_output, consts.TRAIN) output_conf = get_model_output_conf(runner_module, runner_class, runner_conf, - source, - output_path) - logger.info("output_path: {}".format(output_conf)) + source + ) train_model_output.write_metadata(output_conf) elif role.is_arbiter: # is server @@ -221,7 +189,7 @@ def predict( role: Role, test_data: cpn.dataframe_input(roles=[GUEST, HOST]), predict_model_input: cpn.model_directory_input(roles=[GUEST, HOST]), - predict_data_output: cpn.dataframe_output(roles=[GUEST, HOST]) + predict_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True) ): if role.is_guest or role.is_host: # is client @@ -231,13 +199,17 @@ def predict( runner_class = model_conf['runner_class'] runner_conf = model_conf['runner_conf'] source = model_conf['source'] - saved_model_path = model_conf["saved_model_path"] - + saved_model_path = str(predict_model_input.get_directory()) + test_data_ = get_input_data(consts.PREDICT, test_data) runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source) - sub_ctx = prepare_context_and_role(runner, ctx, role, consts.PREDICT) - input_data = get_input_data(consts.PREDICT, test_data, saved_model_path=saved_model_path) - ret: NNOutput = runner.predict(input_data) - handle_nn_output(sub_ctx, ret, predict_data_output, consts.PREDICT) + prepare_context_and_role(runner, ctx, role, consts.PREDICT) + test_pred = runner.predict(test_data_, saved_model_path=saved_model_path) + if test_pred is not None: + assert isinstance(test_pred, DataFrame), "test predict result should be a DataFrame" + add_dataset_type(test_pred, consts.TEST_SET) + predict_data_output.write(test_pred) + else: + logger.warning("test_pred is None, It seems that the runner is not able to predict. 
Failed to output data") elif role.is_arbiter: # is server logger.info("arbiter skip predict") diff --git a/python/fate/components/components/nn/nn_runner.py b/python/fate/components/components/nn/nn_runner.py index c8d20dc54d..2167b0c29c 100644 --- a/python/fate/components/components/nn/nn_runner.py +++ b/python/fate/components/components/nn/nn_runner.py @@ -1,26 +1,22 @@ import numpy as np import torch import pandas as pd -from typing import Union, Optional +from typing import Union, Optional, Literal from fate.components.core import Role from fate.arch import Context -from typing import Optional, Callable, Tuple +from typing import Optional, Union from transformers.trainer_utils import PredictionOutput import numpy as np from fate.arch.dataframe._dataframe import DataFrame from fate.arch.dataframe.manager.schema_manager import Schema from fate.components.components.utils import consts -from fate.components.components.utils.predict_format import get_output_pd_df, LABEL, PREDICT_SCORE import logging +from fate.ml.utils.predict_tools import to_fate_df, std_output_df, add_ids logger = logging.getLogger(__name__) -FATE_DF = 'fate_df' -STR_PATH = 'str_path' - - def _convert_to_numpy_array(data: Union[pd.Series, pd.DataFrame, np.ndarray, torch.Tensor]) -> np.ndarray: if isinstance(data, pd.Series) or isinstance(data, pd.DataFrame): return data.to_numpy() @@ -28,171 +24,7 @@ def _convert_to_numpy_array(data: Union[pd.Series, pd.DataFrame, np.ndarray, tor return data.cpu().numpy() else: return np.array(data) - - -class SampleIDs: - - def __init__(self, sample_id=None, match_id=None, sample_id_name='sample_id', match_id_name='id') -> None: - self.sample_id = sample_id - self.match_id = match_id - self.sample_id_name = sample_id_name - self.match_id_name = match_id_name - - def maybe_generate_ids(self, sample_num: int) -> None: - if self.sample_id is None: - self.sample_id = np.arange(0, sample_num) - if self.match_id is None: - self.match_id = np.arange(0, sample_num) - - def get_id_df(self) -> pd.DataFrame: - return pd.DataFrame({self.sample_id_name: self.sample_id, self.match_id_name: self.match_id}) - - def __repr__(self) -> str: - return f"{self.sample_id_name}: {self.sample_id} \n {self.match_id_name}: {self.match_id}" - - -class NNInput: - """ - Class to encapsulate input data for NN Runner. - - Parameters: - train_data (Union[pd.DataFrame, str]): The training data as a pandas DataFrame or the file path to it. - validate_data (Union[pd.DataFrame, str]): The validation data as a pandas DataFrame or the file path to it. - test_data (Union[pd.DataFrame, str]): The testing data as a pandas DataFrame or the file path to it. - saved_model_path (str): The path of a saved model. - fate_save_path (str): The path for you to save your trained model in current task. 
- """ - - def __init__(self, train_data: Union[pd.DataFrame, str, DataFrame] = None, - validate_data: Union[pd.DataFrame, str, DataFrame] = None, - test_data: Union[pd.DataFrame, str, DataFrame] = None, - saved_model_path: str = None, - fate_save_path: str = None, - ) -> None: - - self.schema = None - self.train_ids = None - self.validate_ids = None - self.test_ids = None - self.input_type = None - - # training - - if isinstance(train_data, DataFrame): - self.train_data, self.train_ids, self.schema = self._extract_fate_df(train_data) - self.input_type = FATE_DF - else: - self.train_data = train_data - self.train_ids = SampleIDs() - self.input_type = STR_PATH - - if isinstance(validate_data, DataFrame): - self.validate_data, self.validate_ids, _ = self._extract_fate_df(validate_data) - else: - self.validate_data = validate_data - self.validate_ids = SampleIDs() - - # prediction - - if isinstance(test_data, DataFrame): - self.test_data, self.test_ids, self.schema = self._extract_fate_df(test_data) - self.input_type = FATE_DF - - else: - self.test_data = test_data - self.test_ids = SampleIDs() - self.input_type = STR_PATH - - self.saved_model_path = saved_model_path - self.fate_save_path = fate_save_path - - def _extract_fate_df(self, df: DataFrame): - schema = df.schema - pd_df = df.as_pd_df() - sample_id = schema.sample_id_name - match_id = schema.match_id_name - ids = SampleIDs(sample_id=pd_df[sample_id].to_numpy(), match_id=pd_df[match_id].to_numpy(), - sample_id_name=sample_id, match_id_name=match_id) - features = pd_df.drop(columns=[sample_id, match_id]) - return features, ids, schema - - def get(self, key: str) -> Union[pd.DataFrame, str]: - return getattr(self, key) - - def get_train_data(self) -> Union[pd.DataFrame, str]: - return self.train_data - - def get_validate_data(self) -> Union[pd.DataFrame, str]: - return self.validate_data - - def get_test_data(self) -> Union[pd.DataFrame, str]: - return self.test_data - - def get_saved_model_path(self) -> str: - return self.saved_model_path - - def get_fate_save_path(self) -> str: - return self.fate_save_path - - def get_train_ids(self) -> SampleIDs: - return self.train_ids - - def get_validate_ids(self) -> SampleIDs: - return self.validate_ids - - def get_test_ids(self) -> SampleIDs: - return self.test_ids - - def get_schema(self) -> Schema: - return self.schema - - def __getitem__(self, key: str): - return self.get(key) - def __repr__(self) -> str: - return f"NNInput(\ntrain_data={self.train_data},\nvalidate_data={self.validate_data}, \ - \ntest_data={self.test_data},\nmodel_path={self.saved_model_path},\nfate_save_path={self.fate_save_path}\n)" - - - -class NNOutput: - - def __init__(self, - train_result: Optional[pd.DataFrame] = None, - validate_result: Optional[pd.DataFrame] = None, - test_result: Optional[pd.DataFrame] = None, - sample_id_name="sample_id", - match_id_name="id", - ) -> None: - - assert isinstance(train_result, pd.DataFrame) or train_result is None - assert isinstance(validate_result, pd.DataFrame) or validate_result is None - assert isinstance(test_result, pd.DataFrame) or test_result is None - self.sample_id_name = sample_id_name - self.match_id_name = match_id_name - - self.train_result = self._check_ids(train_result) - self.validate_result = self._check_ids(validate_result) - self.test_result = self._check_ids(test_result) - - def _check_ids(self, dataframe: pd.DataFrame): - if dataframe is None: - return None - if self.sample_id_name in dataframe.columns and self.match_id_name in dataframe.columns: - return 
dataframe - id_ = SampleIDs(sample_id_name=self.sample_id_name, match_id_name=self.match_id_name) - id_.maybe_generate_ids(len(dataframe)) - id_df = id_.get_id_df() - if self.sample_id_name not in dataframe.columns: - # concat id_df and dataframe - dataframe = pd.concat([id_df[[self.sample_id_name]], dataframe], axis=1) - if self.match_id_name not in dataframe.columns: - dataframe = pd.concat([id_df[[self.match_id_name]], dataframe], axis=1) - return dataframe - - def __repr__(self) -> str: - return f"NNOutput(train_result=\n{self.train_result}\n, validate_result=\n{self.validate_result}\n, test_result=\n{self.test_result}\n)" - def task_type_infer(predict_result, true_label): @@ -215,41 +47,6 @@ def task_type_infer(predict_result, true_label): return None -def get_formatted_output_df(predict_rs: PredictionOutput, id_: SampleIDs, dataset_type, task_type=None, - classes=None, threshold=0.5): - - logger.info("Start to format output dataframe {}".format(type(predict_rs))) - if isinstance(predict_rs, PredictionOutput): - predict_score = predict_rs.predictions - if hasattr(predict_rs, 'label_ids'): - label = predict_rs.label_ids - else: - raise ValueError("predict_rs should be PredictionOutput and label ids should be included in it, but got {}".format(predict_rs)) - - predict_score = _convert_to_numpy_array(predict_score) - label = _convert_to_numpy_array(label) - df = pd.DataFrame() - df[PREDICT_SCORE] = predict_score.tolist() - id_.maybe_generate_ids(len(df)) - id_df = id_.get_id_df() - df = pd.concat([id_df, df], axis=1) - if task_type is None: - task_type = task_type_infer(predict_score, label) - if task_type == consts.BINARY or task_type == consts.MULTI: - if task_type == consts.BINARY: - classes = [0, 1] - else: - classes = np.unique(label).tolist() - if task_type is not None: - return get_output_pd_df(df, label, id_.match_id_name, id_.sample_id_name, dataset_type, task_type, classes, threshold) - else: - df[LABEL] = label - return df - else: - raise ValueError("predict_rs should be PredictionOutput") - - - class NNRunner(object): def __init__(self) -> None: @@ -279,41 +76,116 @@ def set_party_id(self, party_id: int): assert isinstance(self._party_id, int) self._party_id = party_id - @staticmethod - def generate_std_nn_output(input_data: NNInput, - train_eval_prediction: Optional[PredictionOutput] = None, - validate_eval_prediction: Optional[PredictionOutput] = None, - test_eval_prediction: Optional[PredictionOutput] = None, - task_type: str = consts.BINARY, - threshold: float = 0.5) -> NNOutput: + def get_fateboard_tracker(self): + pass - results = {} - match_id_name, sample_id_name = 'id', 'sample_id' - if train_eval_prediction is not None: - ids = input_data.get_train_ids() - match_id_name, sample_id_name = ids.match_id_name, ids.sample_id_name - elif test_eval_prediction is not None: - ids = input_data.get_test_ids() - match_id_name, sample_id_name = ids.match_id_name, ids.sample_id_name + def get_nn_output_dataframe( + self, + ctx, + predictions: Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput], + labels: Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput], + match_ids: Union[pd.DataFrame, np.ndarray] = None, + sample_ids: Union[pd.DataFrame, np.ndarray] = None, + match_id_name: str = None, + sample_id_name: str = None, + dataframe_format: Literal['default', 'fate_std'] = 'default', + task_type: Literal['binary', 'multi', 'regression', 'others'] = None, + threshold: float = 0.5, + classes: list = None + )-> DataFrame: + """ + Constructs a FATE DataFrame from 
predictions and labels. This DataFrame can flow through FATE components.
+
+        Parameters:
+            ctx (Context): The Context instance.
+            predictions (Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput]): The model's predictions, which can be numpy arrays, torch tensors, pandas DataFrames, or PredictionOutputs.
+            labels (Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput]): The true labels, which can be numpy arrays, torch tensors, pandas DataFrames, or PredictionOutputs.
+            match_ids (Union[pd.DataFrame, np.ndarray], optional): Match IDs, if applicable. Defaults to None; if None, match_ids will be auto-generated.
+            sample_ids (Union[pd.DataFrame, np.ndarray], optional): Sample IDs, if applicable. Defaults to None; if None, sample_ids will be auto-generated.
+            match_id_name (str, optional): Column name for match IDs in the resulting DataFrame. Defaults to 'id' if None.
+            sample_id_name (str, optional): Column name for sample IDs in the resulting DataFrame. Defaults to 'sample_id' if None.
+            dataframe_format (Literal['default', 'fate_std'], optional): Output format of the resulting DataFrame. If 'default', simply combines labels and predictions into a DataFrame;
+                if 'fate_std', organizes output according to the FATE framework's format. Defaults to 'default'.
+            task_type (Literal['binary', 'multi', 'regression', 'others'], optional): The type of machine learning task, which can be 'binary', 'multi', 'regression', or 'others'.
+                Only needed when dataframe_format is 'fate_std'. Defaults to None.
+            threshold (float, optional): Only needed when dataframe_format is 'fate_std' and task_type is 'binary'. Defaults to 0.5.
+            classes (list, optional): List of classes. Only needed when dataframe_format is 'fate_std'.
+        Returns:
+            DataFrame: A DataFrame that contains the neural network's predictions and the true labels, possibly along with match IDs and sample IDs, formatted according to the specified format.
+ """ + # check parameters + assert task_type in ['binary', 'multi', 'regression', 'others'], f"task_type {task_type} is not supported" + assert dataframe_format in ['default', 'fate_std'], f"dataframe_format {dataframe_format} is not supported" + + if match_id_name is None: + match_id_name = 'id' + if sample_id_name is None: + sample_id_name = 'sample_id' + + if isinstance(predictions, PredictionOutput): + predictions = predictions.predictions + if isinstance(labels, PredictionOutput): + labels = labels.label_ids + + predictions = _convert_to_numpy_array(predictions) + labels = _convert_to_numpy_array(labels) + assert len(predictions) == len(labels), f"predictions length {len(predictions)} != labels length {len(labels)}" + + # check match ids + if match_ids is not None: + match_ids = _convert_to_numpy_array(match_ids).flatten() else: - raise ValueError('You need to provide either train_eval_prediction or test_eval_prediction') + logger.info("match_ids is not provided, will auto generate match_ids") + match_ids = np.array([i for i in range(len(predictions))]).flatten() - if train_eval_prediction is not None: - results["train"] = get_formatted_output_df(train_eval_prediction, input_data.get_train_ids(), consts.TRAIN_SET, task_type, threshold=threshold) - if validate_eval_prediction is not None: - results["validate"] = get_formatted_output_df(validate_eval_prediction, input_data.get_validate_ids(), consts.VALIDATE_SET, task_type, threshold=threshold) - if test_eval_prediction is not None: - results["test"] = get_formatted_output_df(test_eval_prediction, input_data.get_test_ids(), consts.TEST_SET, task_type, threshold=threshold) + # check sample ids + if sample_ids is not None: + sample_ids = _convert_to_numpy_array(sample_ids).flatten() + else: + logger.info("sample_ids is not provided, will auto generate sample_ids") + sample_ids = np.array([i for i in range(len(predictions))]).flatten() - return NNOutput(train_result=results.get("train"), validate_result=results.get("validate"), test_result=results.get("test"), - match_id_name=match_id_name, sample_id_name=sample_id_name) - - def get_fateboard_tracker(self): - pass + assert len(match_ids) == len(predictions), f"match_ids length {len(match_ids)} != predictions length {len(predictions)}" + assert len(sample_ids) == len(predictions), f"sample_ids length {len(sample_ids)} != predictions length {len(predictions)}" - def train(self, input_data: NNInput = None) -> Optional[Union[NNOutput, None]]: - pass + # match id name and sample id name must be str + assert isinstance(match_id_name, str), f"match_id_name must be str, but got {type(match_id_name)}" + assert isinstance(sample_id_name, str), f"sample_id_name must be str, but got {type(sample_id_name)}" + + if dataframe_format == 'default' or (dataframe_format == 'fate_std' and task_type == 'others'): + df = pd.DataFrame({'label': labels.to_list(), 'predict': predictions.to_list(), match_id_name: match_ids.to_list(), sample_id_name: sample_ids.to_list()}) + df = to_fate_df(ctx, sample_id_name, match_id_name, df) + return df + elif dataframe_format == 'fate_std' and task_type in ['binary', 'multi', 'regression']: + df = std_output_df(task_type, predictions, labels, threshold, classes) + match_id_df = pd.DataFrame() + match_id_df[match_id_name] = match_ids + sample_id_df = pd.DataFrame() + sample_id_df[sample_id_name] = sample_ids + df = add_ids(df, match_id_df, sample_id_df) + df = to_fate_df(ctx, sample_id_name, match_id_name, df) + return df - def predict(self, input_data: NNInput = None) -> 
Optional[Union[NNOutput, None]]:
+
+    def train(self, train_data: Optional[Union[str, DataFrame]] = None, validate_data: Optional[Union[str, DataFrame]] = None, output_dir: str = None, saved_model_path: str = None) -> None:
+        """
+        Train interface.
+
+        Parameters:
+            train_data (Union[str, DataFrame]): The training data, which can be a FATE DataFrame containing the data, or a string path representing the bound data. Train data is optional on the server side.
+            validate_data (Optional[Union[str, DataFrame]]): The validation data, which can be a FATE DataFrame containing the data, or a string path representing the bound data. This argument is optional.
+            output_dir (str, optional): The path to the directory where the trained model should be saved. If this class is running in the FATE pipeline, this path will be provided by the FATE framework.
+            saved_model_path (str, optional): The path to the saved model that should be loaded before training starts. If this class is running in the FATE pipeline, this path will be provided by the FATE framework.
+        """
         pass
 
+    def predict(self, test_data: Optional[Union[str, DataFrame]] = None, output_dir: str = None, saved_model_path: str = None) -> DataFrame:
+        """
+        Predict interface.
+
+        Parameters:
+            test_data (Union[str, DataFrame]): The data to predict, which can be a FATE DataFrame containing the data, or a string path representing the bound data. Test data is optional on the server side.
+            output_dir (str, optional): The path to the directory where the trained model should be saved. If this class is running in the FATE pipeline, this path will be provided by the FATE framework.
+            saved_model_path (str, optional): The path to the saved model that should be loaded before training starts. If this class is running in the FATE pipeline, this path will be provided by the FATE framework.
+ """ + diff --git a/python/fate/components/components/nn/runner/default_runner.py b/python/fate/components/components/nn/runner/default_runner.py index 37bb27b69d..cf0a34caa8 100644 --- a/python/fate/components/components/nn/runner/default_runner.py +++ b/python/fate/components/components/nn/runner/default_runner.py @@ -1,6 +1,6 @@ import torch as t import os -from fate.components.components.nn.nn_runner import NNInput, NNRunner, NNOutput +from fate.components.components.nn.nn_runner import NNRunner from fate.ml.nn.algo.homo.fedavg import FedAVG, FedAVGArguments, FedAVGCLient, FedAVGServer, TrainingArguments from typing import Optional, Dict, Union from fate.components.components.nn.loader import Loader @@ -15,6 +15,8 @@ from typing import Literal import logging from fate.components.components.utils import consts +from fate.ml.nn.dataset.table import TableDataset +from fate.arch.dataframe import DataFrame logger = logging.getLogger(__name__) @@ -25,7 +27,7 @@ def load_model_dict_from_path(path): # Ensure that the path is a string - assert isinstance(path, str), "Path must be a string" + assert isinstance(path, str), "Path must be a string, but got {}".format(type(path)) # Append the filename to the path model_path = os.path.join(path, 'pytorch_model.bin') @@ -54,9 +56,6 @@ class SetupReturn: def __init__(self, trainer: Union[Type[FedTrainerClient], Type[FedTrainerServer]] = None, model: Type[nn.Module] = None, - train_set: Type[data_utils.Dataset] = None, - validate_set: Type[data_utils.Dataset] = None, - test_set: Type[data_utils.Dataset] = None, optimizer: Type[optim.Optimizer] = None, loss: Callable = None, scheduler: Type[_LRScheduler] = None, @@ -70,15 +69,6 @@ def __init__(self, if model is not None and not issubclass(type(model), nn.Module): raise TypeError(f"SetupReturn Error: model must be a subclass of torch.nn.Module but got {type(model)}") - if train_set is not None and not issubclass(type(train_set), data_utils.Dataset): - raise TypeError(f"SetupReturn Error: train_set must be a subclass of torch.utils.data.Dataset but got {type(train_set)}") - - if validate_set is not None and not issubclass(type(validate_set), data_utils.Dataset): - raise TypeError(f"SetupReturn Error: validate_set must be a subclass of torch.utils.data.Dataset but got {type(validate_set)}") - - if test_set is not None and not issubclass(type(test_set), data_utils.Dataset): - raise TypeError(f"SetupReturn Error: test_set must be a subclass of torch.utils.data.Dataset but got {type(test_set)}") - if optimizer is not None and not issubclass(type(optimizer), optim.Optimizer): raise TypeError(f"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}") @@ -99,9 +89,6 @@ def __init__(self, self.trainer = trainer self.model = model - self.train_set = train_set - self.validate_set = validate_set - self.test_set = test_set self.optimizer = optimizer self.loss = loss self.scheduler = scheduler @@ -134,7 +121,7 @@ def __init__(self, data_collator_conf: Optional[Dict] = None, tokenizer_conf: Optional[Dict] = None, task_type: Literal['binary', 'multi', 'regression', 'others'] = 'binary', - use_hf_default_behavior: bool = False, + threshold: float = 0.5, local_mode: bool = False ) -> None: @@ -147,10 +134,21 @@ def __init__(self, self.fed_args_conf = fed_args_conf self.loss_conf = loss_conf self.data_collator_conf = data_collator_conf - self.use_hf_default_behavior = use_hf_default_behavior self.local_mode = local_mode self.tokenizer_conf = tokenizer_conf self.task_type = 
task_type + self.threshold = threshold + + # check param + if self.algo not in SUPPORTED_ALGO: + raise ValueError('algo should be one of {}'.format(SUPPORTED_ALGO)) + if self.task_type not in ['binary', 'multi', 'regression', 'others']: + raise ValueError('task_type should be one of [binary, multi, regression, others]') + assert 0 <= self.threshold <= 1, 'threshold should be in [0, 1]' + assert isinstance(self.local_mode, bool), 'local_mode should be bool' + + # setup var + self.trainer = None def _loader_load_from_conf(self, conf, return_class=False): if conf is None: @@ -159,146 +157,147 @@ def _loader_load_from_conf(self, conf, return_class=False): return Loader.from_dict(conf).load_item() return Loader.from_dict(conf).call_item() - def _prepare_dataset(self, dataset_conf, cpn_input_data, schema=None): + def _prepare_data(self, data, data_name): - if cpn_input_data is None: - logger.info('input cpn data is None, return') - return - - if dataset_conf is None: - # Automatically create dataset class - label_name = None - if schema is not None: - label_name = schema.label_name - if label_name is None: - logger.info('schema is provided, but label name is None, TableDataset will automatically infer label') - else: - logger.info('schema is provided, label name is {}'.format(label_name)) - else: - logger.info('schema is not provided') - + if data is None: + return None + if isinstance(data, DataFrame) and self.dataset_conf is None: + logger.info('Input data {} is FATE DataFrame and dataset conf is None, will automatically handle the input data'.format(data_name)) if self.task_type == consts.MULTI: - dataset = TableDataset(label_col=label_name, flatten_label=True, label_dtype='long') + dataset = TableDataset(flatten_label=True, label_dtype='long', to_tensor=True) else: - dataset = TableDataset(label_col=label_name) - logger.info('dataset conf is not set, use default FATE Table Dataset') - - else: - dataset = self._loader_load_from_conf(dataset_conf) - - if hasattr(dataset, 'load'): - dataset.load(cpn_input_data) - return dataset + dataset = TableDataset(to_tensor=True) + dataset.load(data) else: - raise ValueError(f"dataset {dataset} has no load() method") - - def setup(self, cpn_input_data: NNInput, stage='train'): + dataset = self._loader_load_from_conf(self.dataset_conf) + if hasattr(dataset, 'load'): + dataset.load(data) + else: + raise ValueError(f"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \ Please implement this method in your dataset class.
You can refer to the base class 'Dataset' in 'fate.ml.nn.dataset.base' \ + for the necessary interfaces to implement.") + if dataset is not None and not issubclass(type(dataset), data_utils.Dataset): + raise TypeError(f"SetupReturn Error: {data_name}_set must be a subclass of torch.utils.data.Dataset but got {type(dataset)}") + + return dataset + + def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage='train'): if stage == 'predict': self.local_mode = True if self.algo == 'fedavg': client_class: FedAVGCLient = FedAVG.client - server_class: FedAVGServer = FedAVG.server else: raise ValueError(f"algo {self.algo} not supported") - + ctx = self.get_context() - - if self.is_client(): - - # load arguments, models, etc - # prepare datatset - # dataet - logger.info('NNInput data type is {}'.format(cpn_input_data.input_type)) - train_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_train_data(), schema=cpn_input_data.get_schema()) - validate_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_validate_data(), schema=cpn_input_data.get_schema()) - test_set = self._prepare_dataset(self.dataset_conf, cpn_input_data.get_test_data(), schema=cpn_input_data.get_schema()) - # load model - model = self._loader_load_from_conf(self.model_conf) - if model is None: - raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") - # save path: path to save provided by fate framework - save_path = cpn_input_data.get_fate_save_path() - # if have input model for warm-start - model_path = cpn_input_data.get_saved_model_path() - # resume_from checkpoint path + model = self._loader_load_from_conf(self.model_conf) + if model is None: + raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") + + if output_dir is None: + output_dir = './' + + if saved_model is not None: + model_dict = load_model_dict_from_path(saved_model) + model.load_state_dict(model_dict) + logger.info(f"loading model dict from {saved_model} to model done") + if get_last_checkpoint(saved_model) is not None: + resume_path = saved_model + logger.info(f"checkpoint detected, resume_path set to {resume_path}") + else: resume_path = None - - if model_path is not None: - model_dict = load_model_dict_from_path(model_path) - model.load_state_dict(model_dict) - logger.info(f"loading model dict from {model_path} to model done") - if get_last_checkpoint(model_path) is not None: - resume_path = model_path - logger.info(f"checkpoint detected, resume_path set to {resume_path}") - - # load optimizer - optimizer_loader = Loader.from_dict(self.optimizer_conf) - optimizer_ = optimizer_loader.load_item() - optimizer_params = optimizer_loader.kwargs - optimizer = optimizer_(model.parameters(), **optimizer_params) - # load loss - loss = self._loader_load_from_conf(self.loss_conf) - # load collator func - data_collator = self._loader_load_from_conf(self.data_collator_conf) - # load tokenizer if import conf provided - tokenizer = self._loader_load_from_conf(self.tokenizer_conf) - # args - dir_warning(self.training_args_conf) - training_args = TrainingArguments(**self.training_args_conf) - training_args.output_dir = save_path # reset to default, saving to arbitrary path is not allowed in NN component - training_args.resume_from_checkpoint = resume_path # resume path - fed_args = FedAVGArguments(**self.fed_args_conf) - - # prepare trainer - trainer = client_class(ctx=ctx, model=model, loss_fn=loss, - optimizer=optimizer, training_args=training_args, - 
fed_args=fed_args, data_collator=data_collator, - tokenizer=tokenizer, train_set=train_set, val_set=validate_set, local_mode=self.local_mode) - - return SetupReturn(trainer=trainer, model=model, optimizer=optimizer, loss=loss, - train_args=training_args, fed_args=fed_args, data_collator=data_collator, - train_set=train_set, validate_set=validate_set, test_set=test_set) - elif self.is_server(): - trainer = server_class(ctx=ctx, local_mode=self.local_mode) - return SetupReturn(trainer=trainer) + # load optimizer + optimizer_loader = Loader.from_dict(self.optimizer_conf) + optimizer_ = optimizer_loader.load_item() + optimizer_params = optimizer_loader.kwargs + optimizer = optimizer_(model.parameters(), **optimizer_params) + # load loss + loss = self._loader_load_from_conf(self.loss_conf) + # load collator func + data_collator = self._loader_load_from_conf(self.data_collator_conf) + # load tokenizer if import conf provided + tokenizer = self._loader_load_from_conf(self.tokenizer_conf) + # args + dir_warning(self.training_args_conf) + training_args = TrainingArguments(**self.training_args_conf) + training_args.output_dir = output_dir # reset to default, saving to arbitrary path is not allowed in DefaultRunner + training_args.resume_from_checkpoint = resume_path # resume path + fed_args = FedAVGArguments(**self.fed_args_conf) + + # prepare trainer + trainer = client_class(ctx=ctx, model=model, loss_fn=loss, + optimizer=optimizer, training_args=training_args, + fed_args=fed_args, data_collator=data_collator, + tokenizer=tokenizer, train_set=train_set, val_set=validate_set, local_mode=self.local_mode) + + return SetupReturn(trainer=trainer, model=model, optimizer=optimizer, loss=loss, + train_args=training_args, fed_args=fed_args, data_collator=data_collator) + + def server_setup(self, stage='train'): - def train(self, input_data: NNInput = None) -> Optional[Union[NNOutput, None]]: + if stage == 'predict': + self.local_mode = True + if self.algo == 'fedavg': + server_class: FedAVGServer = FedAVG.server + else: + raise ValueError(f"algo {self.algo} not supported") + ctx = self.get_context() + trainer = server_class(ctx=ctx, local_mode=self.local_mode) + return SetupReturn(trainer=trainer) + def train(self, train_data: Optional[Union[str, DataFrame]] = None, validate_data: Optional[Union[str, DataFrame]] = None, output_dir: str = None, saved_model_path: str = None): - setup = self.setup(input_data, stage='train') - trainer = setup['trainer'] if self.is_client(): - + train_set = self._prepare_data(train_data, 'train_data') + validate_set = self._prepare_data(validate_data, 'val_data') + setup = self.client_setup(train_set=train_set, validate_set=validate_set, output_dir=output_dir, saved_model=saved_model_path) + trainer = setup['trainer'] + self.trainer = trainer trainer.train() - trainer.save_model(input_data.get('fate_save_path')) - # predict the dataset when training is done - train_rs = trainer.predict(setup['train_set']) if setup['train_set'] else None - validate_rs = trainer.predict(setup['validate_set']) if setup['validate_set'] else None - - ret = self.generate_std_nn_output(input_data=input_data, - train_eval_prediction=train_rs, - validate_eval_prediction=validate_rs, - task_type=self.task_type, - threshold=0.5) - - logger.debug(f"train output: {ret}") - - return ret - + if output_dir is not None: + trainer.save_model(output_dir) elif self.is_server(): + setup = self.server_setup() + trainer = setup['trainer'] trainer.train() - def predict(self, input_data: NNInput = None) -> 
Union[NNOutput, None]: + def _run_dataset_func(self, dataset, func_name): - setup = self.setup(input_data, stage='predict') - test_set = setup['test_set'] - trainer = setup['trainer'] - pred_rs = trainer.predict(test_set) - ret = self.generate_std_nn_output(input_data=input_data, test_eval_prediction=pred_rs, task_type=self.task_type, threshold=0.5) - return ret + if hasattr(dataset, func_name): + output = getattr(dataset, func_name)() + if output is None: + logger.info(f'dataset {type(dataset)}: {func_name} returns None, this will influence the output of predict') + return output + else: + logger.info(f'dataset {type(dataset)} not implemented {func_name}, classes set to None, this will influence the output of predict') + return None + + def predict(self, test_data: Union[str, DataFrame], saved_model_path: str = None) -> Union[DataFrame, None]: + + if self.is_client(): + test_set = self._prepare_data(test_data, 'test_data') + if self.trainer is not None: + trainer = self.trainer + logger.info('trainer found, skip setting up') + else: + setup = self.client_setup(saved_model=saved_model_path, stage='predict') + trainer = setup['trainer'] + + classes = self._run_dataset_func(test_set, 'get_classes') + match_ids = self._run_dataset_func(test_set, 'get_match_ids') + sample_ids = self._run_dataset_func(test_set, 'get_sample_ids') + match_id_name = self._run_dataset_func(test_set, 'get_match_id_name') + sample_id_name = self._run_dataset_func(test_set, 'get_sample_id_name') + pred_rs = trainer.predict(test_set) + rs_df = self.get_nn_output_dataframe(self.get_context(), pred_rs.predictions, pred_rs.label_ids, match_ids, sample_ids, match_id_name=match_id_name, sample_id_name=sample_id_name, + dataframe_format='fate_std', task_type=self.task_type, classes=classes) + return rs_df + else: + # server not predict + return diff --git a/python/fate/components/components/nn/test/test_default_runner.py b/python/fate/components/components/nn/test/test_default_runner.py new file mode 100644 index 0000000000..43a5978320 --- /dev/null +++ b/python/fate/components/components/nn/test/test_default_runner.py @@ -0,0 +1,60 @@ +from fate.components.components.nn.runner.default_runner import DefaultRunner +from fate_client.pipeline.components.fate.nn.fate_torch import nn, optim +from fate_client.pipeline.components.fate.nn.fate_torch.base import Sequential +from fate_client.pipeline.components.fate.homo_nn import get_config_of_default_runner +from fate_client.pipeline.components.fate.nn.loader import DatasetLoader +from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments +from fate.arch import Context +from fate.arch.computing.standalone import CSession +from fate.arch.context import Context +from fate.arch.federation.standalone import StandaloneFederation +import pandas as pd +from fate.arch.dataframe import PandasReader +import logging +from fate.components.core import GUEST + +# Get the root logger +logger = logging.getLogger() +logger.setLevel(logging.INFO) +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +ch.setFormatter(formatter) +logger.addHandler(ch) + + +computing = CSession() +ctx = Context( + "guest", + computing=computing, + federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]), +) + +df = pd.read_csv('./../../../../../../examples/data/vehicle_scale_homo_guest.csv') +df['sample_id'] = [i for i in range(len(df))] + +reader = 
PandasReader(sample_id_name='sample_id', match_id_name="id", label_name="y", dtype="object") +data = reader.to_frame(ctx, df) + +runner_conf = get_config_of_default_runner( + algo='fedavg', + model=Sequential( + nn.Linear(18, 10), + nn.ReLU(), + nn.Linear(10, 4), + nn.Softmax() + ), + loss=nn.CrossEntropyLoss(), + dataset=DatasetLoader('table', 'TableDataset', flatten_label=True, label_dtype='long'), + optimizer=optim.Adam(lr=0.01), + training_args=TrainingArguments(num_train_epochs=50, per_device_train_batch_size=128), + fed_args=FedAVGArguments(), + task_type='multi' + ) + +runner = DefaultRunner(**runner_conf) +runner.set_context(ctx) +runner.set_role(GUEST) +runner.local_mode = True +runner.train(data) +rs = runner.predict(data) diff --git a/python/fate/components/components/utils/predict_format.py b/python/fate/components/components/utils/predict_format.py deleted file mode 100644 index bbafb6c5df..0000000000 --- a/python/fate/components/components/utils/predict_format.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np -from fate.arch.dataframe._dataframe import DataFrame -import pandas as pd -from pandas import DataFrame as pd_DataFrame -from typing import Union -from fate.components.components.utils import consts -from fate.arch.dataframe import PandasReader -from fate.arch.dataframe.manager.schema_manager import Schema -import json - -# variable: -LABEL = "label" -PREDICT_LABEL = "predict_result" -PREDICT_SCORE = "predict_score" -PREDICT_DETAIL = "predict_detail" -TYPE = "type" - - - -def predict_detail_dict_to_str(result_dict): - return "\"" + json.dumps(result_dict).replace("\"", "\'") + "\"" - - -def predict_detail_str_to_dict(result_dict_str): - return json.loads(json.loads(result_dict_str).replace("\'", "\"")) - - -def get_output_pd_df(pred_table: pd_DataFrame, label: Union[pd_DataFrame, pd.Series, np.ndarray], match_id_name, sample_id_name, dataset_type=consts.TRAIN_SET, - task_type=consts.BINARY, classes=None, threshold=0.5): - - df = pred_table - - if match_id_name not in pred_table.columns: - raise ValueError(f"match_id_column {match_id_name} not in predict_table whose columns are {pred_table.columns}") - - if sample_id_name not in pred_table.columns: - raise ValueError(f"sample_id_column {sample_id_name} not in predict_table whose columns are {pred_table.columns}") - - if PREDICT_SCORE not in pred_table.columns: - raise ValueError(f"predict_score not in predict_table whose columns are {pred_table.columns}") - - pred_rs_df = pd.DataFrame() - pred_rs_df[match_id_name] = df[match_id_name] - pred_rs_df[sample_id_name] = df[sample_id_name] - pred_rs_df[PREDICT_SCORE] = df[PREDICT_SCORE] - if task_type == consts.BINARY: - if classes is None: - raise ValueError("classes must be specified positive and negative when task_type is binary, example: [0, 1] as negative, positive") - class_neg, class_pos = classes[0], classes[1] - pred_rs_df[PREDICT_SCORE] = pred_rs_df[PREDICT_SCORE].apply(lambda x: x[0]) - pred_rs_df[PREDICT_LABEL] = df[PREDICT_SCORE].apply(lambda x: class_pos if x[0] >= threshold else class_neg) - pred_rs_df[PREDICT_DETAIL] = df[PREDICT_SCORE].apply(lambda x: predict_detail_dict_to_str({class_pos: x[0], class_neg: 1 - x[0]})) - elif task_type == consts.MULTI: - if classes is None: - raise ValueError("classes must be specified when task_type is multi") - pred_rs_df[PREDICT_LABEL] = df[PREDICT_SCORE].apply(lambda x: classes[x.index(max(x))]) - pred_rs_df[PREDICT_DETAIL] = df[PREDICT_SCORE].apply(lambda x: predict_detail_dict_to_str({classes[i]: x[i] for i in
range(len(x))})) - elif task_type == consts.REGRESSION: - pred_rs_df[PREDICT_SCORE] = pred_rs_df[PREDICT_SCORE].apply(lambda x: x[0]) - pred_rs_df[PREDICT_LABEL] = pred_rs_df[PREDICT_SCORE] - pred_rs_df[PREDICT_DETAIL] = df[PREDICT_SCORE].apply(lambda x: predict_detail_dict_to_str({"label": x[0]})) - else: - raise ValueError(f"task_type {task_type} is not supported") - - pred_rs_df[TYPE] = dataset_type - pred_rs_df[LABEL] = label - - return pred_rs_df - - -def predict_score_to_output(ctx, pred_table:DataFrame, train_data: DataFrame, dataset_type=consts.TRAIN_SET, - task_type=consts.BINARY, classes=None, threshold=0.5) -> DataFrame: - - assert dataset_type in [consts.TRAIN_SET, consts.TEST_SET, consts.VALIDATION_SET], f"dataset_type {dataset_type} is not supported" - - if isinstance(pred_table, DataFrame): - df: pd_DataFrame = pred_table.as_pd_df() - else: - raise TypeError(f"predict_table type {type(pred_table)} is not supported") - - schema = train_data.schema - label_name = schema.label_name - label_df = train_data[label_name].as_pd_df() - label = label_df[label_name] - match_id_name = schema.match_id_name - sample_id_name = schema.sample_id_name - pred_rs_df = get_output_pd_df(df, label, match_id_name, sample_id_name, dataset_type, task_type, classes, threshold) - reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, label_name=LABEL, dtype="object") - output_df = reader.to_frame(ctx, pred_rs_df) - - return output_df diff --git a/python/fate/components/components/utils/tools.py b/python/fate/components/components/utils/tools.py index 650f330001..d1e9fb92de 100644 --- a/python/fate/components/components/utils/tools.py +++ b/python/fate/components/components/utils/tools.py @@ -1,5 +1,8 @@ from fate.arch.dataframe import DataFrame -from .consts import TRAIN_SET, VALIDATE_SET, TESET_SET +from .consts import TRAIN_SET, VALIDATE_SET, TEST_SET + + +TYPE = 'type' def cat_train_and_validate_df(train_df: DataFrame, val_df: DataFrame): @@ -10,7 +13,8 @@ def cat_train_and_validate_df(train_df: DataFrame, val_df: DataFrame): def add_dataset_type(df: DataFrame, dataset_type): - assert dataset_type in [TRAIN_SET, VALIDATE_SET, TESET_SET], f"dataset_type must be one of {TRAIN_SET}, {VALIDATE_SET}, {TESET_SET}" + assert dataset_type in [TRAIN_SET, VALIDATE_SET, TEST_SET], f"dataset_type must be one of {TRAIN_SET}, {VALIDATE_SET}, {TEST_SET}" + df[TYPE] = dataset_type return df diff --git a/python/fate/ml/evaluation/metric_base.py b/python/fate/ml/evaluation/metric_base.py index 29dbeee2c3..20766bcb8b 100644 --- a/python/fate/ml/evaluation/metric_base.py +++ b/python/fate/ml/evaluation/metric_base.py @@ -111,14 +111,13 @@ def _parse_input(self, eval_rs): def __call__(self, eval_rs=None, predict=None, label=None, **kwargs) -> Dict: - metric_result = {} + metric_result = [] if eval_rs is not None: predict, label, input_ = self._parse_input(eval_rs) for metric in self._metrics: rs = metric(predict, label) - logger.info('metric: {}, result: {}'.format(metric.metric_name, rs)) if isinstance(rs, tuple): new_rs = [r.to_dict() for r in rs] rs = new_rs @@ -126,7 +125,7 @@ def __call__(self, eval_rs=None, predict=None, label=None, **kwargs) -> Dict: rs = rs.to_dict() else: raise ValueError('cannot parse metric result: {}'.format(rs)) - metric_result[metric.metric_name] = rs + metric_result.append(rs) return metric_result def fit(self, eval_rs=None, predict=None, label=None, **kwargs) -> Dict: diff --git a/python/fate/ml/glm/__init__.py b/python/fate/ml/glm/__init__.py index 
f508993cb8..4495033129 100644 --- a/python/fate/ml/glm/__init__.py +++ b/python/fate/ml/glm/__init__.py @@ -1,2 +1,4 @@ from .hetero.coordinated_linr import CoordinatedLinRModuleHost, CoordinatedLinRModuleGuest, CoordinatedLinRModuleArbiter from .hetero.coordinated_lr import CoordinatedLRModuleHost, CoordinatedLRModuleGuest, CoordinatedLRModuleArbiter +from .homo.lr.client import HomoLRClient +from .homo.lr.server import HomoLRServer diff --git a/python/fate/ml/glm/homo_lr/client.py b/python/fate/ml/glm/homo/lr/client.py similarity index 68% rename from python/fate/ml/glm/homo_lr/client.py rename to python/fate/ml/glm/homo/lr/client.py index 5920af2df6..daa8d45a67 100644 --- a/python/fate/ml/glm/homo_lr/client.py +++ b/python/fate/ml/glm/homo/lr/client.py @@ -5,58 +5,20 @@ from fate.ml.utils.model_io import ModelIO from fate.arch import Context import logging -import pandas as pd import torch as t from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, TrainingArguments, FedAVGArguments from transformers import default_data_collator -import numpy as np -from torch.nn import functional as F import functools import tempfile -from torch.utils.data import Dataset -from fate.ml.utils.predict_format import std_output_df, add_ids, to_fate_df -from fate.ml.utils.predict_format import MULTI, BINARY +from fate.ml.utils.predict_tools import std_output_df, add_ids, to_fate_df +from fate.ml.utils.predict_tools import MULTI, BINARY +from fate.ml.nn.dataset.table import TableDataset +from fate.ml.utils._optimizer import optimizer_factory, lr_scheduler_factory logger = logging.getLogger(__name__) -class Data(object): - - def __init__(self, features: pd.DataFrame, sample_ids: pd.DataFrame, match_ids: pd.DataFrame, labels: pd.DataFrame) -> None: - # set var - self.features = features - self.sample_ids = sample_ids - self.match_ids = match_ids - self.labels = labels - - def get_match_id_name(self): - return self.match_ids.columns[0] - - def get_sample_id_name(self): - return self.sample_ids.columns[0] - - def has_label(self): - return self.labels is not None - - @staticmethod - def from_fate_dataframe(df: DataFrame): - schema = df.schema - sample_id = schema.sample_id_name - match_id = schema.match_id_name - label = schema.label_name - pd_df = df.as_pd_df() - if label is None: - labels = None - features = pd_df.drop([sample_id, match_id], axis=1) - else: - labels = pd_df[[label]] - features = pd_df.drop([sample_id, match_id, label], axis=1) - sample_ids = pd_df[[sample_id]] - match_ids = pd_df[[match_id]] - - return Data(features, sample_ids, match_ids, labels) - def homo_lr_loss(pred, labels, dim=1): """ @@ -66,10 +28,9 @@ def homo_lr_loss(pred, labels, dim=1): # initialize the loss loss = 0.0 - if dim == 2: - dim -= 1 - loss_fn = t.nn.BCELoss() + if dim <= 2: + return loss_fn(pred[:, 0], labels) for c in range(dim): # get binary labels for this class @@ -85,7 +46,7 @@ def homo_lr_loss(pred, labels, dim=1): class HomoLRModel(t.nn.Module): - def __init__(self, feature_num, label_num=2, l1=0) -> None: + def __init__(self, feature_num, label_num=2, l1=0, bias=True) -> None: super().__init__() assert feature_num >= 2 and isinstance(feature_num, int), "feature_num must be int greater than 2" assert label_num >= 1 and isinstance(label_num, int), "label_num must be int greater than 1" @@ -93,13 +54,13 @@ def __init__(self, feature_num, label_num=2, l1=0) -> None: if 2 >= label_num > 0: self.models.append( - t.nn.Linear(feature_num, 1) + t.nn.Linear(feature_num, 1, bias=bias) ) else: # OVR Setting for i in 
range(label_num): self.models.append( - t.nn.Linear(feature_num, 1) + t.nn.Linear(feature_num, 1, bias=bias) ) self.sigmoid = t.nn.Sigmoid() self.softmax = t.nn.Softmax(dim=1) @@ -148,13 +109,13 @@ def from_dict(cls, model_dict): return model -def init_model(model, method='random', val=1.0): +def init_model(model, method='random', fill_val=1.0): if method == 'zeros': init_fn = nn.init.zeros_ elif method == 'ones': init_fn = nn.init.ones_ elif method == 'consts': - init_fn = lambda x: nn.init.constant_(x, val) + init_fn = lambda x: nn.init.constant_(x, fill_val) elif method == 'random': init_fn = nn.init.normal_ else: @@ -188,33 +149,32 @@ def get_torch_bytes(model_dict): return model_saved_bytes -class DictDataset(Dataset): - """TensorDataset with support of transforms. - """ - def __init__(self, data): - self.X = np.array(data.features.values).astype(np.float32) - self.X_tensor = t.tensor(self.X, dtype=t.float32) - if data.labels is None: - self.y = None - else: - self.y = np.array(data.labels.values).astype(np.float32) - self.y_tensor = t.tensor(self.y.reshape((-1, 1)), dtype=t.float32) - - def __getitem__(self, index): - if self.y is not None: - return {'x': self.X_tensor[index], 'label': self.y_tensor[index]} - else: - return {'x': self.X_tensor[index]} +def update_params(new_params, default, name='optimizer'): + import copy + params = copy.deepcopy(default) + if not isinstance(new_params, dict): + raise ValueError("{} param dict must be a dict but got {}".format(name, new_params)) + def _update(default, new): + for key in new.keys(): + if key in default: + default[key] = new[key] + + _update(params, new_params) + + return params + + +DEFAULT_OPT_PARAM = {'method': 'sgd', 'penalty': 'l2', 'alpha': 0.0, 'optimizer_params': {'lr': 0.01, 'weight_decay': 0}} +DEFAULT_INIT_PARAM = {"method": "random", "fill_val": 1.0, "fit_intercept": True} +DEFAULT_LR_SCHEDULER_PARAM = {'method': 'constant', 'scheduler_params': {'factor': 1.0}} - def __len__(self): - return self.X_tensor.shape[0] - class HomoLRClient(HomoModule): - def __init__(self, epochs: int=5, batch_size: int=32, optimizer_param=None, - learning_rate_scheduler=None, - init_param=None, + def __init__(self, epochs: int=5, batch_size: int=32, + optimizer_param={'method': 'sgd', 'optimizer_params': {'lr': 0.01, 'weight_decay': 0}}, + learning_rate_scheduler={'method': 'constant', 'scheduler_params': {'factor': 1.0}}, + init_param={"method": "random", "fill_val": 1.0, "fit_intercept": True}, threshold: float=0.5, ovr=False, label_num=None, @@ -222,16 +182,16 @@ def __init__(self, epochs: int=5, batch_size: int=32, optimizer_param=None, super().__init__() self.df_schema = None - self.train_data = None - self.validate_data = None - self.predict_data = None + self.train_set = None + self.validate_set = None + self.predict_set = None # set vars self.max_iter = epochs self.batch_size = batch_size - self.optimizer_param = optimizer_param - self.learning_rate_param = learning_rate_scheduler - self.init_param = init_param + self.optimizer_param = update_params(optimizer_param, DEFAULT_OPT_PARAM, name='optimizer') + self.learning_rate_param = update_params(learning_rate_scheduler, DEFAULT_LR_SCHEDULER_PARAM, name='learning_rate_scheduler') + self.init_param = update_params(init_param, DEFAULT_INIT_PARAM, name='init_param') self.threshold = threshold self.run_ovr = False self.train_feature_num = None @@ -253,25 +213,31 @@ def __init__(self, epochs: int=5, batch_size: int=32, optimizer_param=None, # loaded meta self.loaded_meta = None - # l1 & l2 + # 
reg self.l1 = 0 self.l2 = 0 + # for testing + self.local_mode = False + # checkping param assert self.max_iter > 0 and isinstance(self.max_iter, int), "max_iter must be int greater than 0" - assert self.batch_size > 0 and isinstance(self.batch_size, int), "batch_size must be int greater than 0" + if self.batch_size != -1: + assert self.batch_size > 0 and isinstance(self.batch_size, int), "batch_size must be int greater than 0 or -1" assert self.threshold > 0 and self.threshold < 1, "threshold must be float between 0 and 1" - def _make_dataset(self, data: Data): - return DictDataset(data) + def _make_dataset(self, data) -> TableDataset: + ds = TableDataset(return_dict=True, to_tensor=True) + ds.load(data) + return ds - def _make_output_df(self, predict_rs, data: Data, threshold: float): + def _make_output_df(self, predict_rs, data: TableDataset, threshold: float): classes = [i for i in range(len(self.model.models))] if len(classes) == 1: # binary: classes = [0, 1] task_type = BINARY if len(classes) == 2 else MULTI out_df = std_output_df(task_type, predict_rs.predictions, predict_rs.label_ids, threshold=threshold, classes=classes) - out_df = add_ids(out_df, data.match_ids, data.sample_ids) + out_df = add_ids(out_df, data.get_match_ids(), data.get_sample_ids()) return out_df def _check_labels(self, label_set, has_validate=False): @@ -297,52 +263,50 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No if validate_data is not None: assert isinstance(validate_data, DataFrame), "validate_data must be a fate DataFrame" - self.train_data: Data = Data.from_fate_dataframe(train_data) - if not self.train_data.has_label(): + self.train_set = self._make_dataset(train_data) + if not self.train_set.has_label(): raise RuntimeError("train data must have label column") - self.train_feature_num = self.train_data.features.values.shape[1] - unique_label_set = set(self.train_data.labels.values.reshape(-1)) + self.train_feature_num = self.train_set.features.shape[1] + unique_label_set = set(self.train_set.get_classes()) if validate_data is not None: - self.validate_data = Data.from_fate_dataframe(validate_data) - if not self.validate_data.has_label(): + self.validate_set = self._make_dataset(validate_data) + if not self.validate_set.has_label(): raise RuntimeError("validate data must have label column") - self.validate_feature_num = self.validate_data.features.values.shape[1] + self.validate_feature_num = self.validate_set.features.shape[1] assert self.train_feature_num == self.validate_feature_num, "train and validate feature num not match: {} vs {}".format(self.train_feature_num, self.validate_feature_num) - unique_label_set = unique_label_set.union(set(self.validate_data.labels.values.reshape(-1))) + unique_label_set = unique_label_set.union(set(self.validate_set.get_classes())) self._check_labels(unique_label_set, validate_data is not None) - if validate_data is not None: - unique_label_set = unique_label_set.union(set(self.validate_data.labels.values.reshape(-1))) - logger.info("unique label set updated to: {}".format(unique_label_set)) - - train_set = self._make_dataset(self.train_data) - - if self.validate_data is not None: - validate_set = self._make_dataset(self.validate_data) - else: - validate_set = None + if self.batch_size == -1: + self.batch_size = len(self.train_set) # prepare loss function loss_fn = functools.partial(homo_lr_loss, dim=len(unique_label_set)) + optimizer_params = self.optimizer_param['optimizer_params'] + opt_method = self.optimizer_param['method'] + 
if self.optimizer_param['penalty'] == 'l2': + self.l2 = self.optimizer_param['alpha'] + optimizer_params['weight_decay'] = self.l2 + elif self.optimizer_param['penalty'] == 'l1': + self.l1 = self.optimizer_param['alpha'] # initialize model if self.model is None: - - self.model = HomoLRModel(self.train_feature_num, label_num=len(unique_label_set), l1=self.l1) - + fit_intercept = self.init_param["fit_intercept"] + self.model = HomoLRModel(self.train_feature_num, label_num=len(unique_label_set), l1=self.l1, bias=fit_intercept) # init model here - init_model(self.model) - + init_model(self.model, method=self.init_param["method"], fill_val=self.init_param["fill_val"]) logger.info('model initialized') logger.info('model parameters are {}'.format(list(self.model.parameters()))) else: logger.info('model is loaded, warm start training') logger.info('model structure is {}'.format(self.model)) - # initialize optimizer - self.optimizer = t.optim.SGD(self.model.parameters(), lr=self.learning_rate_param, weight_decay=self.l2) + self.optimizer = optimizer_factory(self.model.parameters(), opt_method, optimizer_params) + self.lr_scheduler = lr_scheduler_factory(self.optimizer, self.learning_rate_param['method'], self.learning_rate_param['scheduler_params']) + if self.optimizer_state_dict is not None: optimizer_state_dict = { "state": {k: t.tensor(v) for k, v in self.optimizer_state_dict['state'].items()}, @@ -354,9 +318,11 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No # training fed_arg = FedAVGArguments() train_arg = TrainingArguments(num_train_epochs=self.max_iter, - per_device_train_batch_size=self.batch_size, per_gpu_eval_batch_size=self.batch_size) - self.trainer = FedAVGCLient(ctx, model=self.model, loss_fn=loss_fn, optimizer=self.optimizer, train_set=train_set, - val_set=validate_set, training_args=train_arg, fed_args=fed_arg, data_collator=default_data_collator) + per_device_train_batch_size=self.batch_size, per_device_eval_batch_size=self.batch_size) + self.trainer = FedAVGCLient(ctx, model=self.model, loss_fn=loss_fn, optimizer=self.optimizer, train_set=self.train_set, + val_set=self.validate_set, training_args=train_arg, fed_args=fed_arg, data_collator=default_data_collator, scheduler=self.lr_scheduler) + if self.local_mode: # for debugging + self.trainer.set_local_mode() self.trainer.train() logger.info('training finished') @@ -365,18 +331,20 @@ def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: if self.model is None: raise ValueError("model is not initialized") - self.predict_data = Data.from_fate_dataframe(predict_data) - predict_set = self._make_dataset(self.predict_data) + self.predict_set = self._make_dataset(predict_data) if self.trainer is None: - train_arg = TrainingArguments(num_train_epochs=self.max_iter, per_device_eval_batch_size=self.batch_size) - trainer = FedAVGCLient(ctx, train_set=predict_set, model=self.model, training_args=train_arg, + batch_size = len(self.predict_set) if self.batch_size == -1 else self.batch_size + train_arg = TrainingArguments(num_train_epochs=self.max_iter, per_device_eval_batch_size=batch_size) + trainer = FedAVGCLient(ctx, train_set=self.predict_set, model=self.model, training_args=train_arg, fed_args=FedAVGArguments(), data_collator=default_data_collator) trainer.set_local_mode() else: trainer = self.trainer - predict_rs = trainer.predict(predict_set) - predict_out_df = self._make_output_df(predict_rs, self.predict_data, self.threshold) - return to_fate_df(ctx, 
self.predict_data.get_sample_id_name(), self.predict_data.get_match_id_name(), predict_out_df) + predict_rs = trainer.predict(self.predict_set) + predict_out_df = self._make_output_df(predict_rs, self.predict_set, self.threshold) + match_id_name = self.predict_set.get_match_ids().columns[0] + sample_id_name = self.predict_set.get_sample_ids().columns[0] + return to_fate_df(ctx, sample_id_name, match_id_name, predict_out_df) def get_model(self) -> ModelIO: param = {} diff --git a/python/fate/ml/glm/homo_lr/server.py b/python/fate/ml/glm/homo/lr/server.py similarity index 100% rename from python/fate/ml/glm/homo_lr/server.py rename to python/fate/ml/glm/homo/lr/server.py diff --git a/python/fate/ml/glm/homo_lr/test/local_test.py b/python/fate/ml/glm/homo/lr/test/local_test.py similarity index 65% rename from python/fate/ml/glm/homo_lr/test/local_test.py rename to python/fate/ml/glm/homo/lr/test/local_test.py index d1661f7d22..6582d5cbe2 100644 --- a/python/fate/ml/glm/homo_lr/test/local_test.py +++ b/python/fate/ml/glm/homo/lr/test/local_test.py @@ -4,9 +4,8 @@ from fate.arch.federation.standalone import StandaloneFederation import pandas as pd from fate.arch.dataframe import PandasReader -from fate.ml.glm.homo_lr.client import HomoLRClient, HomoLRModel -import logging - +from fate.ml.nn.dataset.table import TableDataset +from fate.ml.glm.homo.lr.client import HomoLRClient import logging # Get the root logger @@ -26,20 +25,26 @@ federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]), ) -df = pd.read_csv('./../../../../../../examples/data/breast_homo_guest.csv') +df = pd.read_csv('/home/cwj/FATE/FATE-2.0-pure/FATE/examples/data/breast_homo_guest.csv') df['sample_id'] = [i for i in range(len(df))] reader = PandasReader(sample_id_name='sample_id', match_id_name="id", label_name="y", dtype="object") reader_2 = PandasReader(sample_id_name='sample_id', match_id_name="id", dtype="object") data = reader.to_frame(ctx, df) + +# df = data.as_pd_df() data_2 = reader_2.to_frame(ctx, df.drop(columns=['y'])) +ds = TableDataset(return_dict=True, to_tensor=True) +ds.load(data) -client = HomoLRClient(50, 800, learning_rate_scheduler=0.01) +client = HomoLRClient(50, 800, optimizer_param={'method': 'adam', 'penalty': 'l1', 'alpha': 0.1, 'optimizer_params': {'lr': 0.1}}, init_param={'method': 'random', 'fill_val': 1.0}, + learning_rate_scheduler={'method': 'linear', 'scheduler_params': {'start_factor': 1.0}} +) client.l2 = 0.01 client.l1 = 0.01 +client.local_mode = True client.fit(ctx, data, validate_data=data) export_model = client.get_model() pred = client.predict(ctx, data) -pred_2 = client.predict(ctx, data_2) +# pred_2 = client.predict(ctx, data_2) diff --git a/python/fate/ml/nn/dataset/base.py b/python/fate/ml/nn/dataset/base.py index 97951450d4..7fd161d91d 100644 --- a/python/fate/ml/nn/dataset/base.py +++ b/python/fate/ml/nn/dataset/base.py @@ -1,94 +1,16 @@ from torch.utils.data import Dataset as Dataset_ -from federatedml.nn.backend.utils.common import ML_PATH, LLM_PATH -import importlib import abc -import numpy as np +import pandas as pd class Dataset(Dataset_): def __init__(self, **kwargs): super(Dataset, self).__init__() - self._type = 'local' # train/predict - self._check = False - self._generated_ids = None - self.training = True - - @property - def dataset_type(self): - if not hasattr(self, '_type'): - raise AttributeError( - 'type variable not exists, call __init__ of super class') - return self._type - - @dataset_type.setter - def dataset_type(self,
val): - self._type = val - - def has_dataset_type(self): - return self.dataset_type - - def set_type(self, _type): - self.dataset_type = _type - - def get_type(self): - return self.dataset_type - - def has_sample_ids(self): - - # if not implement get_sample_ids, return False - try: - sample_ids = self.get_sample_ids() - except NotImplementedError as e: - return False - except BaseException as e: - raise e - - if sample_ids is None: - return False - else: - if not self._check: - assert isinstance( - sample_ids, list), 'get_sample_ids() must return a list contains str or integer' - for id_ in sample_ids: - if (not isinstance(id_, str)) and (not isinstance(id_, int)): - raise RuntimeError( - 'get_sample_ids() must return a list contains str or integer: got id of type {}:{}'.format( - id_, type(id_))) - assert len(sample_ids) == len( - self), 'sample id len:{} != dataset length:{}'.format(len(sample_ids), len(self)) - self._check = True - return True - - def init_sid_and_getfunc(self, prefix: str = None): - if prefix is not None: - assert isinstance( - prefix, str), 'prefix must be a str, but got {}'.format(prefix) - else: - prefix = self._type - generated_ids = [] - for i in range(0, self.__len__()): - generated_ids.append(prefix + '_' + str(i)) - self._generated_ids = generated_ids - - def get_func(): - return self._generated_ids - self.get_sample_ids = get_func - - """ - Functions for users - """ - - def train(self, ): - self.training = True - - def eval(self, ): - self.training = False # Function to implemented - @abc.abstractmethod - def load(self, file_path): + def load(self, data_or_path): raise NotImplementedError( 'You must implement load function so that Client can pass file-path to this ' 'class') @@ -98,85 +20,21 @@ def __getitem__(self, item): def __len__(self): raise NotImplementedError() + + def has_label(self) -> bool: + pass - def get_classes(self): - raise NotImplementedError() - - def get_sample_ids(self): - raise NotImplementedError() - - -class ShuffleWrapDataset(Dataset_): - - def __init__(self, dataset: Dataset, shuffle_seed=100): - super(ShuffleWrapDataset, self).__init__() - self.ds = dataset - ids = self.ds.get_sample_ids() - sort_idx = np.argsort(np.array(ids)) - assert isinstance(dataset, Dataset) - self.idx = sort_idx - if shuffle_seed is not None: - np.random.seed(shuffle_seed) - self.shuffled_idx = np.copy(self.idx) - np.random.shuffle(self.shuffled_idx) - else: - self.shuffled_idx = np.copy(self.idx) - self.idx_map = {k: v for k, v in zip(self.idx, self.shuffled_idx)} - - def train(self, ): - self.ds.train() - - def eval(self, ): - self.ds.eval() - - def __getitem__(self, item): - return self.ds[self.idx_map[self.idx[item]]] - - def __len__(self): - return len(self.ds) - - def __repr__(self): - return self.ds.__repr__() - - def has_sample_ids(self): - return self.ds.has_sample_ids() - - def set_shuffled_idx(self, idx_map: dict): - self.shuffled_idx = np.array(list(idx_map.values())) - self.idx_map = idx_map - - def get_sample_ids(self): - ids = self.ds.get_sample_ids() - return np.array(ids)[self.shuffled_idx].tolist() - - def get_classes(self): - return self.ds.get_classes() + def get_classes(self) -> list: + pass + def get_match_ids(self) -> pd.DataFrame: + pass + + def get_sample_ids(self) -> pd.DataFrame: + pass -def get_dataset_class(dataset_module_name: str): + def get_sample_id_name(self) -> str: + pass - if dataset_module_name.endswith('.py'): - dataset_module_name = dataset_module_name.replace('.py', '') - try: - ds_modules = importlib.import_module( - 
'{}.dataset.{}'.format( - ML_PATH, dataset_module_name) - ) - except BaseException: - ds_modules = importlib.import_module( - '{}.dataset.{}'.format( - LLM_PATH, dataset_module_name) - ) - try: - ds = [] - for k, v in ds_modules.__dict__.items(): - if isinstance(v, type): - if issubclass(v, Dataset) and v is not Dataset: - ds.append(v) - if len(ds) == 0: - raise ValueError('Did not find any class in {}.py that is the subclass of Dataset class'. - format(dataset_module_name)) - else: - return ds[-1] # return the last defined class - except ValueError as e: - raise e + def get_match_id_name(self) -> str: + pass \ No newline at end of file diff --git a/python/fate/ml/nn/dataset/table.py b/python/fate/ml/nn/dataset/table.py index b15072134e..0414570349 100644 --- a/python/fate/ml/nn/dataset/table.py +++ b/python/fate/ml/nn/dataset/table.py @@ -1,6 +1,12 @@ import numpy as np import pandas as pd -from torch.utils.data import Dataset +from fate.arch.dataframe import DataFrame +from fate.ml.nn.dataset.base import Dataset +import logging +import torch as t + + +logger = logging.getLogger(__name__) class TableDataset(Dataset): @@ -11,26 +17,34 @@ class TableDataset(Dataset): Parameters ---------- label_col str, name of label column in csv, if None, will automatically take 'y' or 'label' or 'target' as label - feature_dtype dtype of feature, supports int, long, float, double - label_dtype: dtype of label, supports int, long, float, double - label_shape: list or tuple, the shape of label - flatten_label: bool, flatten extracted label column or not, default is False + match_id_col str, name of match id column in csv, if None, will automatically take 'id' or 'sid' as match id + sample_id_col str, name of sample id column in csv, if None, will automatically generate sample id + feature_dtype str, dtype of features, available: 'long', 'int', 'float', 'double' + label_dtype str, dtype of label, available: 'long', 'int', 'float', 'double' + label_shape tuple or list, shape of label, if None, will automatically infer from data + flatten_label bool, whether to flatten label, if True, will flatten label to 1-d array + to_tensor bool, whether to transform data to pytorch tensor, if True, will transform data to tensor + return_dict bool, whether to return a dict in the format of {'x': xxx, 'label': xxx} if True, will return a dict, else will return a tuple """ def __init__( - self, label_col=None, feature_dtype="float", label_dtype="float", label_shape=None, flatten_label=False + self, label_col=None, match_id_col=None, sample_id_col=None, + feature_dtype="float", label_dtype="float", label_shape=None, flatten_label=False, + to_tensor=True, return_dict=False ): super(TableDataset, self).__init__() - self.with_label = True - self.with_sample_weight = False self.features: np.ndarray = None self.label: np.ndarray = None self.sample_weights: np.ndarray = None self.origin_table: pd.DataFrame = pd.DataFrame() self.label_col = label_col + self.match_id_col = match_id_col + self.sample_id_col = sample_id_col self.f_dtype = self.check_dtype(feature_dtype) self.l_dtype = self.check_dtype(label_dtype) + self.to_tensor = to_tensor + self.return_dict = return_dict if label_shape is not None: assert isinstance(label_shape, tuple) or isinstance(label_shape, list), "label shape is {}".format( label_shape @@ -38,7 +52,7 @@ def __init__( self.label_shape = label_shape self.flatten_label = flatten_label - # ids, match ids is for FATE match id system + # sample ids, match ids self.sample_ids = None self.match_ids = None @@ -65,90 
+79,104 @@ def check_dtype(dtype): def __getitem__(self, item): - if self.with_label: - if self.with_sample_weight and self.training: - return self.features[item], (self.label[item], self.sample_weights[item]) + if self.label is not None: + feat = self.features[item] + label = self.label[item] + if self.to_tensor: + feat = t.tensor(feat) + label = t.tensor(label) + if self.return_dict: + return {"x": feat, "label": label} else: - return self.features[item], self.label[item] + return feat, label else: - return self.features[item] + feat = self.features[item] + if self.to_tensor: + feat = t.tensor(feat) + if self.return_dict: + return {"x": feat} + else: + return feat def __len__(self): - return len(self.origin_table) + return len(self.features) - def load(self, file_path): + def load(self, data_or_path): - if isinstance(file_path, str): - self.origin_table = pd.read_csv(file_path) - elif isinstance(file_path, pd.DataFrame): - self.origin_table = file_path - else: + if isinstance(data_or_path, str): + self.origin_table = pd.read_csv(data_or_path) # if is FATE DTable, collect data and transform to array format - data_inst = file_path - self.with_sample_weight = None - print("collecting FATE DTable, with sample weight is {}".format(self.with_sample_weight)) - header = data_inst.scheme["header"] - print("input dtable header is {}".format(header)) - data = list(data_inst.collect()) - data_keys = [key for (key, val) in data] - data_keys_map = dict(zip(sorted(data_keys), range(len(data_keys)))) - - keys = [None for idx in range(len(data_keys))] - x_ = [None for idx in range(len(data_keys))] - y_ = [None for idx in range(len(data_keys))] - match_ids = {} - sample_weights = [1 for idx in range(len(data_keys))] - - for (key, inst) in data: - idx = data_keys_map[key] - keys[idx] = key - x_[idx] = inst.features - y_[idx] = inst.label - match_ids[key] = inst.inst_id - if self.with_sample_weight: - sample_weights[idx] = inst.weight - - x_ = np.asarray(x_) - y_ = np.asarray(y_) - df = pd.DataFrame(x_) - df.columns = header - df["id"] = sorted(data_keys) - df["label"] = y_ - # host data has no label, so this columns will all be None - if df["label"].isna().all(): - df = df.drop(columns=["label"]) - - self.origin_table = df - self.sample_weights = np.array(sample_weights) - self.match_ids = match_ids + label_col_candidates = ["y", "label", "target"] + # automatically set id columns + if self.match_id_col is not None: + if self.match_id_col not in self.origin_table: + raise ValueError("match id column {} not found".format(self.match_id_col)) + else: + self.match_ids = self.origin_table[[self.match_id_col]] + self.origin_table = self.origin_table.drop(columns=[self.match_id_col]) + else: + match_id_col_cadidaites = ["id", "sid"] + for id_col in match_id_col_cadidaites: + if id_col in self.origin_table: + self.match_ids = self.origin_table[[id_col]] + self.origin_table = self.origin_table.drop(columns=[id_col]) + break + if self.match_ids is None: + logger.info("match id column not found, no match id will be set") + + # generate sample ids + if self.sample_id_col is not None: + if self.sample_id_col not in self.origin_table: + raise ValueError("sample id column {} not found".format(self.sample_id_col)) + self.sample_ids = self.origin_table[[self.sample_id_col]] + self.origin_table = self.origin_table.drop(columns=[self.sample_id_col]) + else: + self.sample_ids = pd.DataFrame() + self.sample_ids["sample_id"] = range(len(self.origin_table)) + logger.info("sample id column not found, generate sample id from 0 to 
{}".format(len(self.origin_table))) - label_col_candidates = ["y", "label", "target"] - - # automatically set id columns - id_col_candidates = ["id", "sid"] - for id_col in id_col_candidates: - if id_col in self.origin_table: - self.sample_ids = self.origin_table[id_col].values.tolist() - self.origin_table = self.origin_table.drop(columns=[id_col]) - break - - # infer column name - label = self.label_col - if label is None: - for i in label_col_candidates: - if i in self.origin_table: - label = i - break + # infer column name + label = self.label_col if label is None: - self.with_label = False - print('label default setting is "auto", but found no "y"/"label"/"target" in input table') - else: - if label not in self.origin_table: - raise ValueError("label column {} not found in input table".format(label)) + for i in label_col_candidates: + if i in self.origin_table: + label = i + break + if label is None: + self.with_label = False + logger.info('found no "y"/"label"/"target" in input table, no label will be set') + else: + if label not in self.origin_table: + raise ValueError("label column {} not found in input table".format(label)) + + if self.label is not None: + self.label = self.origin_table[[label]].values + self.origin_table = self.origin_table.drop(columns=[label]) + self.features = self.origin_table.values + + elif isinstance(data_or_path, DataFrame): + schema = data_or_path.schema + sample_id = schema.sample_id_name + match_id = schema.match_id_name + label = schema.label_name + if label is None: + logger.info("label column is None, not provided in the uploaded data") + pd_df = data_or_path.as_pd_df() + if label is None: + labels = None + features = pd_df.drop([sample_id, match_id], axis=1) + else: + labels = pd_df[[label]] + features = pd_df.drop([sample_id, match_id, label], axis=1) + self.label = labels.values + sample_ids = pd_df[[sample_id]] + match_ids = pd_df[[match_id]] + self.sample_ids = sample_ids + self.match_ids = match_ids + self.features = features.values + - if self.with_label: - self.label = self.origin_table[label].values - self.features = self.origin_table.drop(columns=[label]).values + if self.label is not None: if self.l_dtype: self.label = self.label.astype(self.l_dtype) @@ -160,10 +188,9 @@ def load(self, file_path): if self.flatten_label: self.label = self.label.flatten() - + else: self.label = None - self.features = self.origin_table.values if self.f_dtype: self.features = self.features.astype(self.f_dtype) @@ -174,8 +201,23 @@ def get_classes(self): else: raise ValueError("no label found, please check if self.label is set") - def get_sample_ids(self): + def get_sample_ids(self) -> pd.DataFrame: return self.sample_ids - def get_match_ids(self): + def get_match_ids(self) -> pd.DataFrame: return self.match_ids + + def get_sample_id_name(self) -> str: + if self.sample_ids is not None and isinstance(self.sample_ids, pd.DataFrame): + return self.sample_ids.columns[0] + else: + raise ValueError('Cannot get sample id name') + + def get_match_id_name(self) -> str: + if self.match_ids is not None and isinstance(self.match_ids, pd.DataFrame): + return self.match_ids.columns[0] + else: + raise ValueError('Cannot get match id name') + + def has_label(self) -> bool: + return self.label is not None diff --git a/python/fate/ml/nn/trainer/trainer_base.py b/python/fate/ml/nn/trainer/trainer_base.py index c7887ddb8b..208d782152 100644 --- a/python/fate/ml/nn/trainer/trainer_base.py +++ b/python/fate/ml/nn/trainer/trainer_base.py @@ -32,8 +32,6 @@ 
transformers_logging.disable_default_handler() transformers_logging.enable_propagation() logger = logging.getLogger(__name__) -# trainer.logger = logging.getLogger("transformers trainer") -# trainer_callback.logger = logger def time_decorator(descr=""): diff --git a/python/fate/ml/utils/predict_format.py b/python/fate/ml/utils/predict_format.py deleted file mode 100644 index 994c622078..0000000000 --- a/python/fate/ml/utils/predict_format.py +++ /dev/null @@ -1,84 +0,0 @@ -import pandas as pd -from fate.arch.dataframe import PandasReader -import numpy as np - - -TRAIN_SET = 'train_set' -VALIDATE_SET = 'validate_set' -TEST_SET = 'test_set' -LABEL = "label" -PREDICT_LABEL = "predict_result" -PREDICT_SCORE = "predict_score" -PREDICT_DETAIL = "predict_detail" -TYPE = "type" - -# TASK TYPE -BINARY = 'binary' -MULTI = 'multi' -REGRESSION = 'regression' -OTHER = 'other' - - -def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id:pd.DataFrame): - df = pd.concat([df, match_id, sample_id], axis=1) - return df - - -def add_dataset_type(df: pd.DataFrame, ds_type): - - assert ds_type in [TRAIN_SET, VALIDATE_SET, TEST_SET], 'ds_type must be one of {}, but got {}'.format([TRAIN_SET, VALIDATE_SET, TEST_SET], ds_type) - df[TYPE] = ds_type - return df - - -def to_fate_df(ctx, sample_id_name, match_id_name, result_df): - - if LABEL in result_df: - reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, label_name=LABEL, dtype="object") - else: - reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, dtype="object") - data = reader.to_frame(ctx, result_df) - return data - - -def std_output_df(task_type, pred: np.array, label: np.array=None, threshold=0.5, classes: list = None): - - assert task_type in [BINARY, MULTI, REGRESSION, OTHER], 'task_type must be one of {} as a std task, but got {}'.format([BINARY, MULTI, REGRESSION, OTHER], task_type) - - if task_type == BINARY: - if len(classes) == 2: - predict_score = pred - predict_result = (predict_score > threshold).astype(int) - predict_details = [{classes[0]: 1 - float(predict_score[i]), classes[1]: float(predict_score[i])} for i in range(len(predict_score))] - else: - raise ValueError('task_type is binary, but classes length is not 2: {}'.format(classes)) - - elif task_type == MULTI: - if len(classes) > 2: - predict_score = pred.max(axis=1) - predict_result = np.argmax(pred, axis=1) - predict_details = [{classes[j]: float(pred[i][j]) for j in range(len(classes))} for i in range(len(pred))] - else: - raise ValueError('task_type is multi, but classes length is not greater than 2: {}'.format(classes)) - - elif task_type == REGRESSION: - # regression task - predict_score = pred - predict_result = pred - predict_details = [{LABEL: float(pred[i])} for i in range(len(pred))] - - if label is None: - df = pd.DataFrame({ - PREDICT_SCORE: predict_score.flatten(), - PREDICT_LABEL: predict_result.flatten(), - PREDICT_DETAIL: predict_details - }) - else: - df = pd.DataFrame({ - PREDICT_SCORE: predict_score.flatten(), - PREDICT_LABEL: predict_result.flatten(), - LABEL: label.flatten(), - PREDICT_DETAIL: predict_details - }) - - return df \ No newline at end of file diff --git a/python/fate/ml/utils/predict_tools.py b/python/fate/ml/utils/predict_tools.py new file mode 100644 index 0000000000..39b199a73d --- /dev/null +++ b/python/fate/ml/utils/predict_tools.py @@ -0,0 +1,112 @@ +import json +import pandas as pd +from fate.arch.dataframe import PandasReader +import numpy as np +from typing import Union +from 
fate.arch.dataframe import DataFrame + + +TRAIN_SET = 'train_set' +VALIDATE_SET = 'validate_set' +TEST_SET = 'test_set' +LABEL = "label" +PREDICT_RESULT = "predict_result" +PREDICT_SCORE = "predict_score" +PREDICT_DETAIL = "predict_detail" + +# TASK TYPE +BINARY = 'binary' +MULTI = 'multi' +REGRESSION = 'regression' +OTHER = 'other' + + +def predict_detail_dict_to_str(result_dict): + return "\"" + json.dumps(result_dict).replace("\"", "\'") + "\"" + + +def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id: pd.DataFrame): + df = pd.concat([df, match_id, sample_id], axis=1) + return df + + +def to_fate_df(ctx, sample_id_name, match_id_name, result_df): + + if LABEL in result_df: + reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, label_name=LABEL, dtype="object") + else: + reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, dtype="object") + data = reader.to_frame(ctx, result_df) + return data + + +def compute_predict_details(dataframe: Union[pd.DataFrame, DataFrame], task_type, classes: list = None, threshold=0.5): + + assert task_type in [BINARY, MULTI, REGRESSION, OTHER], 'task_type must be one of {} as a std task, but got {}'.format([BINARY, MULTI, REGRESSION, OTHER], task_type) + if isinstance(dataframe, DataFrame): + df = dataframe.as_pd_df() + else: + df = dataframe + + pred = df[PREDICT_SCORE].values if PREDICT_SCORE in df else None + if pred is None: + raise ValueError('pred score is not found in input dataframe') + + if task_type in (BINARY, MULTI) and classes is None: + raise ValueError('task_type is binary or multi, but classes is None') + + if task_type == BINARY: + if len(classes) == 2: + predict_score = np.array(pred) + predict_result = (predict_score > threshold).astype(int) + predict_details = [{classes[0]: 1 - float(predict_score[i]), classes[1]: float(predict_score[i])} for i in range(len(predict_score))] + else: + raise ValueError('task_type is binary, but classes length is not 2: {}'.format(classes)) + + elif task_type == MULTI: + if len(classes) > 2: + predict_score = np.array([max(i) for i in pred]) + predict_result = np.array([np.argmax(i) for i in pred]) + predict_details = [predict_detail_dict_to_str({classes[j]: float(pred[i][j]) for j in range(len(classes))}) for i in range(len(pred))] + else: + raise ValueError('task_type is multi, but classes length is not greater than 2: {}'.format(classes)) + + elif task_type == REGRESSION: + # regression task + predict_score = np.array(pred) + predict_result = np.array(pred) + predict_details = [{LABEL: float(pred[i])} for i in range(len(pred))] + + df[PREDICT_RESULT] = predict_result + df[PREDICT_DETAIL] = predict_details + if task_type == MULTI: + df[PREDICT_SCORE] = predict_score + + return df + + +def std_output_df(task_type, pred: np.array, label: np.array = None, threshold=0.5, classes: list = None): + + df = pd.DataFrame() + if len(pred.shape) == 1: + df[PREDICT_SCORE] = np.array(pred) + elif len(pred.shape) == 2: + if pred.shape[1] == 1: + df[PREDICT_SCORE] = np.array(pred).flatten() + else: + df[PREDICT_SCORE] = np.array(pred).tolist() + else: + raise ValueError('This is not a FATE std task, pred scores shape are {}'.format(pred.shape)) + + if label is not None: + if len(label.shape) == 1: + label = label.flatten() + elif len(label.shape) == 2 and label.shape[1] == 1: + label = label.flatten() + else: + label = label.tolist() + df[LABEL] = label + + df = compute_predict_details(df, task_type, classes, threshold) + + return df \ No
newline at end of file From 6adc10467bd4741c5e93692411057e5fd7416c52 Mon Sep 17 00:00:00 2001 From: cwj Date: Thu, 13 Jul 2023 19:01:17 +0800 Subject: [PATCH 37/61] Make pep8 happy Signed-off-by: cwj --- fate_client | 2 +- python/fate/components/components/homo_lr.py | 53 +- python/fate/components/components/homo_nn.py | 98 +- .../components/nn/fate_torch/base.py | 18 +- .../components/components/nn/fate_torch/nn.py | 1731 ++++++++++++----- .../components/nn/fate_torch/optim.py | 276 ++- .../fate/components/components/nn/loader.py | 59 +- .../components/components/nn/nn_runner.py | 91 +- .../components/nn/runner/default_runner.py | 236 ++- .../components/nn/runner/my_runner.py | 22 +- .../components/nn/test/test_default_runner.py | 58 +- python/fate/ml/glm/homo/lr/client.py | 262 ++- python/fate/ml/glm/homo/lr/server.py | 14 +- python/fate/ml/glm/homo/lr/test/local_test.py | 34 +- python/fate/ml/nn/algo/homo/fedavg.py | 75 +- python/fate/ml/nn/dataset/base.py | 6 +- python/fate/ml/nn/dataset/table.py | 80 +- python/fate/ml/nn/model_zoo/multi_model.py | 2 +- python/fate/ml/nn/trainer/trainer_base.py | 475 +++-- python/fate/ml/utils/_convergence.py | 7 +- python/fate/ml/utils/_model_param.py | 4 +- python/fate/ml/utils/_optimizer.py | 82 +- python/fate/ml/utils/model_io.py | 4 +- python/fate/ml/utils/model_serdes.py | 3 +- python/fate/ml/utils/predict_tools.py | 66 +- 25 files changed, 2564 insertions(+), 1194 deletions(-) diff --git a/fate_client b/fate_client index ab81005987..df88e67f88 160000 --- a/fate_client +++ b/fate_client @@ -1 +1 @@ -Subproject commit ab81005987e43a6771b6316007e5e81c20480669 +Subproject commit df88e67f88f89ba7207aa2aa54ace68187193abf diff --git a/python/fate/components/components/homo_lr.py b/python/fate/components/components/homo_lr.py index fa988d5100..96b4c00ef6 100644 --- a/python/fate/components/components/homo_lr.py +++ b/python/fate/components/components/homo_lr.py @@ -25,7 +25,6 @@ logger = logging.getLogger(__name__) - @cpn.component(roles=[GUEST, HOST, ARBITER]) def homo_lr(ctx, role): ... 
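The prediction frames produced by these components use the predict_score / predict_result / predict_detail layout built by compute_predict_details in predict_tools.py above. As a minimal standalone sketch of the binary branch (plain numpy/pandas only, independent of the FATE DataFrame and reader machinery; the function name binary_predict_frame is illustrative, and json.dumps stands in for the quote-escaping helper predict_detail_dict_to_str):

import json

import numpy as np
import pandas as pd


def binary_predict_frame(scores, classes, threshold=0.5):
    # scores: 1-D array of positive-class probabilities
    scores = np.asarray(scores, dtype=float)
    result = (scores > threshold).astype(int)
    # one detail entry per sample: probability of each class
    details = [
        json.dumps({str(classes[0]): float(1 - s), str(classes[1]): float(s)})
        for s in scores
    ]
    return pd.DataFrame({
        "predict_score": scores,
        "predict_result": result,
        "predict_detail": details,
    })


# two samples with positive-class scores 0.3 and 0.9 -> results 0 and 1
print(binary_predict_frame([0.3, 0.9], classes=[0, 1]))

The homo_lr hunks that follow only reflow these component signatures; the output schema they emit is unchanged.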
@@ -38,27 +37,27 @@ def train(
     train_data: cpn.dataframe_input(roles=[GUEST, HOST]),
     validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True),
     learning_rate_scheduler: cpn.parameter(type=params.lr_scheduler_param(),
-                                           default=params.LRSchedulerParam(method="constant",
-                                                                           scheduler_params={"factor": 1.0}),
-                                           desc="learning rate scheduler, "
-                                                "select method from {'step', 'linear', 'constant'}"
+                                            default=params.LRSchedulerParam(method="constant",
+                                                                            scheduler_params={"factor": 1.0}),
+                                            desc="learning rate scheduler, "
+                                            "select method from {'step', 'linear', 'constant'}; "
                                                 "for list of configurable arguments, "
                                                 "refer to torch.optim.lr_scheduler"),
     epochs: cpn.parameter(type=params.conint(gt=0), default=20,
-                          desc="max iteration num"),
+                           desc="max iteration num"),
     batch_size: cpn.parameter(type=params.conint(ge=-1), default=100,
-                              desc="batch size, "
-                              "value less or equals to 0 means full batch"),
+                               desc="batch size, "
+                               "value less than or equal to 0 means full batch"),
     optimizer: cpn.parameter(type=params.optimizer_param(),
-                             default=params.OptimizerParam(method="sgd", penalty='l2', alpha=1.0,
-                                                           optimizer_params={"lr": 1e-2, "weight_decay": 0})),
+                              default=params.OptimizerParam(method="sgd", penalty='l2', alpha=1.0,
+                                                            optimizer_params={"lr": 1e-2, "weight_decay": 0})),
     init_param: cpn.parameter(type=params.init_param(),
-                              default=params.InitParam(method='random', fit_intercept=True),
-                              desc="Model param init setting."),
+                               default=params.InitParam(method='random', fit_intercept=True),
+                               desc="Model param init setting."),
     threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5,
-                             desc="predict threshold for binary data"),
+                              desc="predict threshold for binary data"),
     ovr: cpn.parameter(type=bool, default=False,
-                       desc="enable ovr for multi-classifcation"),
+                        desc="enable ovr for multi-classification"),
     label_num: cpn.parameter(type=params.conint(ge=2), default=None),
     train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
     train_input_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True),
@@ -68,13 +67,21 @@ def train(
     sub_ctx = ctx.sub_ctx(consts.TRAIN)
     if role.is_guest or role.is_host:  # is client
-
+
         logger.info('homo lr component: client start training')
-        logger.info('optim param {} \n init param {} \n learning rate param {}'.format(optimizer.dict(), init_param.dict(), learning_rate_scheduler.dict()))
+        logger.info('optim param {} \n init param {} \n learning rate param {}'.format(
+            optimizer.dict(), init_param.dict(), learning_rate_scheduler.dict()))
+
+        client = HomoLRClient(
+            epochs=epochs,
+            batch_size=batch_size,
+            optimizer_param=optimizer.dict(),
+            init_param=init_param.dict(),
+            learning_rate_scheduler=learning_rate_scheduler.dict(),
+            threshold=threshold,
+            ovr=ovr,
+            label_num=label_num)
-        client = HomoLRClient(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer.dict(), init_param=init_param.dict(),
-                              learning_rate_scheduler=learning_rate_scheduler.dict(), threshold=threshold, ovr=ovr, label_num=label_num)
-
         if train_input_model is not None:
             model_input = train_input_model.read()
             client.from_model(model_input)
@@ -83,7 +90,7 @@ def train(
         validate_df = validate_data.read() if validate_data else None
         client.fit(sub_ctx, train_df, validate_df)
         model_dict = client.get_model().dict()
-
+
         train_rs = client.predict(sub_ctx, train_df)
         train_rs = add_dataset_type(train_rs, consts.TRAIN_SET)
         if validate_df:
@@ -108,10 +115,10 @@ def predict(
     role: Role,
     test_data: cpn.dataframe_input(roles=[GUEST, HOST]),
     batch_size: cpn.parameter(type=params.conint(ge=-1), default=100,
-                              desc="batch size, "
-                              "value less or equals to 0 means full batch"),
+                               desc="batch size, "
+                               "value less than or equal to 0 means full batch"),
     threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5,
-                             desc="predict threshold for binary data"),
+                              desc="predict threshold for binary data"),
     predict_input_model: cpn.json_model_input(roles=[GUEST, HOST]),
     test_output_data: cpn.dataframe_output(roles=[GUEST, HOST])
 ):
diff --git a/python/fate/components/components/homo_nn.py b/python/fate/components/components/homo_nn.py
index 4b837b4f93..b0aec26813 100644
--- a/python/fate/components/components/homo_nn.py
+++ b/python/fate/components/components/homo_nn.py
@@ -39,14 +39,22 @@ def prepare_runner_class(runner_module, runner_class, runner_conf, source):
     logger.info("runner conf is {}".format(runner_conf))
     logger.info("source is {}".format(source))
     if runner_module != "fate_runner":
-        if source == None:
+        if source is None:
             # load from default folder
-            runner = Loader("fate.components.components.nn.runner." + runner_module, runner_class, **runner_conf)()
+            runner = Loader(
+                "fate.components.components.nn.runner." +
+                runner_module,
+                runner_class,
+                **runner_conf)()
         else:
-            runner = Loader(runner_module, runner_class, source=source, **runner_conf)()
-        assert isinstance(runner, NNRunner), "loaded class must be a subclass of NNRunner class, but got {}".format(
-            type(runner)
-        )
+            runner = Loader(
+                runner_module,
+                runner_class,
+                source=source,
+                **runner_conf)()
+        assert isinstance(
+            runner, NNRunner), "loaded class must be a subclass of NNRunner class, but got {}".format(
+            type(runner))
     else:
         logger.info("using default fate runner")
         runner = DefaultRunner(**runner_conf)
@@ -61,7 +69,7 @@ def prepare_context_and_role(runner, ctx, role, sub_ctx_name):


 def get_input_data(stage, cpn_input_data):
-
+
     if stage == 'train':
         train_data, validate_data = cpn_input_data
         train_data = train_data.read()
@@ -69,7 +77,7 @@ def get_input_data(stage, cpn_input_data):
             validate_data = validate_data.read()

         return train_data, validate_data
-
+
     elif stage == 'predict':
         test_data = cpn_input_data
         test_data = test_data.read()
@@ -82,11 +90,12 @@ def get_input_data(stage, cpn_input_data):
 Output functions
 """

+
 def get_model_output_conf(runner_module,
-                         runner_class,
-                         runner_conf,
-                         source,
-                         ):
+                          runner_class,
+                          runner_conf,
+                          source,
+                          ):
     return {
         "runner_module": runner_module,
         "runner_class": runner_class,
@@ -95,14 +104,18 @@ def get_model_output_conf(runner_module,
     }


-
-def prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, source):
+def prepared_saved_conf(
+        model_conf,
+        runner_class,
+        runner_module,
+        runner_conf,
+        source):

     logger.info("loaded model_conf is: {}".format(model_conf))
     if "source" in model_conf:
         if source is None:
             source = model_conf["source"]
-
+
     runner_class_, runner_module_ = model_conf['runner_class'], model_conf['runner_module']
     if runner_class_ == runner_class and runner_module_ == runner_module:
         if "runner_conf" in model_conf:
@@ -111,9 +124,11 @@ def prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, so
             runner_conf = saved_conf
             logger.info("runner_conf is updated: {}".format(runner_conf))
     else:
-        logger.warning("runner_class or runner_module is not equal to the saved model, "
-                       "use the new runner_conf, runner_class and runner module to train the model,\
-                       saved module & class: {} {}, new module & class: {} {}".format(runner_module_, runner_class_, runner_module, runner_class))
+        logger.warning(
+            "runner_class or runner_module is not equal to the saved model, "
+            "use the new runner_conf, runner_class and runner module to train the model,\
+            saved module & class: {} {}, new module & class: {} {}".format(
+                runner_module_, runner_class_, runner_module, runner_class))

     return runner_conf, source, runner_class, runner_module

@@ -138,31 +153,36 @@ def train(
     train_model_input: cpn.model_directory_input(roles=[GUEST, HOST], optional=True),
 ):

-    runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source)
+    runner: NNRunner = prepare_runner_class(
+        runner_module, runner_class, runner_conf, source)
     prepare_context_and_role(runner, ctx, role, consts.TRAIN)

     if role.is_guest or role.is_host:  # is client
-
+
         if train_model_input is not None:
             model_conf = train_model_input.get_metadata()
-            runner_conf, source, runner_class, runner_module = prepared_saved_conf(model_conf, runner_class, runner_module, runner_conf, source)
+            runner_conf, source, runner_class, runner_module = prepared_saved_conf(
+                model_conf, runner_class, runner_module, runner_conf, source)
             saved_model_path = str(train_model_input.get_directory())
         else:
             saved_model_path = None

         output_dir = str(train_model_output.get_directory())
-        train_data_, validate_data_ = get_input_data(consts.TRAIN, [train_data, validate_data])
+        train_data_, validate_data_ = get_input_data(
+            consts.TRAIN, [train_data, validate_data])
         runner.train(train_data_, validate_data_, output_dir, saved_model_path)

         logger.info('Predicting Train & Validate Data')
         train_pred = runner.predict(train_data_, saved_model_path)
         if train_pred is not None:
-            assert isinstance(train_pred, DataFrame), "train predict result should be a DataFrame"
+            assert isinstance(
+                train_pred, DataFrame), "train predict result should be a DataFrame"
             add_dataset_type(train_pred, consts.TRAIN_SET)

             if validate_data_ is not None:
                 validate_pred = runner.predict(validate_data_)
-                assert isinstance(validate_pred, DataFrame), "validate predict result should be a DataFrame"
+                assert isinstance(
+                    validate_pred, DataFrame), "validate predict result should be a DataFrame"
                 add_dataset_type(validate_pred, consts.VALIDATE_SET)
                 output_df = DataFrame.vstack([train_pred, validate_pred])
             else:
@@ -170,7 +190,8 @@ def train(
             logger.info('write result dataframe')
             train_data_output.write(output_df)
         else:
-            logger.warning("train_pred is None, It seems that the runner is not able to predict. Failed to output data")
+            logger.warning(
+                "train_pred is None, it seems that the runner is not able to predict. Failed to output data")

         output_conf = get_model_output_conf(runner_module,
                                             runner_class,
@@ -178,19 +199,20 @@ def train(
                                             source
                                             )
         train_model_output.write_metadata(output_conf)
-
+
     elif role.is_arbiter:  # is server
         runner.train()


 @homo_nn.predict()
 def predict(
-    ctx,
-    role: Role,
-    test_data: cpn.dataframe_input(roles=[GUEST, HOST]),
-    predict_model_input: cpn.model_directory_input(roles=[GUEST, HOST]),
-    predict_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True)
-):
+        ctx, role: Role, test_data: cpn.dataframe_input(
+            roles=[
+                GUEST, HOST]), predict_model_input: cpn.model_directory_input(
+            roles=[
+                GUEST, HOST]), predict_data_output: cpn.dataframe_output(
+            roles=[
+                GUEST, HOST], optional=True)):

     if role.is_guest or role.is_host:  # is client

@@ -201,15 +223,19 @@ def predict(
             source = model_conf['source']
         saved_model_path = str(predict_model_input.get_directory())
         test_data_ = get_input_data(consts.PREDICT, test_data)
-        runner: NNRunner = prepare_runner_class(runner_module, runner_class, runner_conf, source)
+        runner: NNRunner = prepare_runner_class(
+            runner_module, runner_class, runner_conf, source)
         prepare_context_and_role(runner, ctx, role, consts.PREDICT)
-        test_pred = runner.predict(test_data_, saved_model_path=saved_model_path)
+        test_pred = runner.predict(
+            test_data_, saved_model_path=saved_model_path)
         if test_pred is not None:
-            assert isinstance(test_pred, DataFrame), "test predict result should be a DataFrame"
+            assert isinstance(
+                test_pred, DataFrame), "test predict result should be a DataFrame"
             add_dataset_type(test_pred, consts.TEST_SET)
             predict_data_output.write(test_pred)
         else:
-            logger.warning("test_pred is None, It seems that the runner is not able to predict. Failed to output data")
+            logger.warning(
+                "test_pred is None, it seems that the runner is not able to predict. 
Failed to output data") elif role.is_arbiter: # is server logger.info("arbiter skip predict") diff --git a/python/fate/components/components/nn/fate_torch/base.py b/python/fate/components/components/nn/fate_torch/base.py index 8e052ed368..9dd59dd7c0 100644 --- a/python/fate/components/components/nn/fate_torch/base.py +++ b/python/fate/components/components/nn/fate_torch/base.py @@ -4,14 +4,14 @@ import json - def convert_tuples_to_lists(data): if isinstance(data, tuple): return list(data) elif isinstance(data, list): return [convert_tuples_to_lists(item) for item in data] elif isinstance(data, dict): - return {key: convert_tuples_to_lists(value) for key, value in data.items()} + return {key: convert_tuples_to_lists( + value) for key, value in data.items()} else: return data @@ -24,9 +24,9 @@ def __init__(self): self.optimizer = None def to_dict(self): - ret_dict ={ + ret_dict = { 'module_name': 'torch.nn', - 'item_name': str(type(self).__name__), + 'item_name': str(type(self).__name__), 'kwargs': convert_tuples_to_lists(self.param_dict) } return ret_dict @@ -39,9 +39,9 @@ def __init__(self): self.torch_class = None def to_dict(self): - ret_dict ={ + ret_dict = { 'module_name': 'torch.optim', - 'item_name': type(self).__name__, + 'item_name': type(self).__name__, 'kwargs': convert_tuples_to_lists(self.param_dict) } return ret_dict @@ -103,14 +103,12 @@ def to_dict(self): ordered_name = idx layer_confs[ordered_name] = self._modules[k].to_dict() idx += 1 - ret_dict ={ + ret_dict = { 'module_name': 'fate.components.components.nn.fate_torch.base', - 'item_name': load_seq.__name__, + 'item_name': load_seq.__name__, 'kwargs': {'seq_conf': layer_confs} } return ret_dict def to_json(self): return json.dumps(self.to_dict(), indent=4) - - \ No newline at end of file diff --git a/python/fate/components/components/nn/fate_torch/nn.py b/python/fate/components/components/nn/fate_torch/nn.py index c7f255dd72..fa0ea5bb68 100644 --- a/python/fate/components/components/nn/fate_torch/nn.py +++ b/python/fate/components/components/nn/fate_torch/nn.py @@ -3,8 +3,16 @@ class Bilinear(nn.modules.linear.Bilinear, FateTorch): - - def __init__(self, in1_features, in2_features, out_features, bias=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + in1_features, + in2_features, + out_features, + bias=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['bias'] = bias self.param_dict['device'] = device @@ -14,19 +22,25 @@ def __init__(self, in1_features, in2_features, out_features, bias=True, device=N self.param_dict['out_features'] = out_features self.param_dict.update(kwargs) nn.modules.linear.Bilinear.__init__(self, **self.param_dict) - - + + class Identity(nn.modules.linear.Identity, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.linear.Identity.__init__(self, **self.param_dict) - - + + class LazyLinear(nn.modules.linear.LazyLinear, FateTorch): - - def __init__(self, out_features, bias=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + out_features, + bias=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['bias'] = bias self.param_dict['device'] = device @@ -34,11 +48,18 @@ def __init__(self, out_features, bias=True, device=None, dtype=None, **kwargs): self.param_dict['out_features'] = out_features self.param_dict.update(kwargs) nn.modules.linear.LazyLinear.__init__(self, **self.param_dict) - - + + class Linear(nn.modules.linear.Linear, 
FateTorch): - - def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + in_features, + out_features, + bias=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['bias'] = bias self.param_dict['device'] = device @@ -47,11 +68,20 @@ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None self.param_dict['out_features'] = out_features self.param_dict.update(kwargs) nn.modules.linear.Linear.__init__(self, **self.param_dict) - - -class NonDynamicallyQuantizableLinear(nn.modules.linear.NonDynamicallyQuantizableLinear, FateTorch): - - def __init__(self, in_features, out_features, bias=True, device=None, dtype=None, **kwargs): + + +class NonDynamicallyQuantizableLinear( + nn.modules.linear.NonDynamicallyQuantizableLinear, + FateTorch): + + def __init__( + self, + in_features, + out_features, + bias=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['bias'] = bias self.param_dict['device'] = device @@ -59,20 +89,28 @@ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None self.param_dict['in_features'] = in_features self.param_dict['out_features'] = out_features self.param_dict.update(kwargs) - nn.modules.linear.NonDynamicallyQuantizableLinear.__init__(self, **self.param_dict) - - + nn.modules.linear.NonDynamicallyQuantizableLinear.__init__( + self, **self.param_dict) + + class GRU(nn.modules.rnn.GRU, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.rnn.GRU.__init__(self, **self.param_dict) - - + + class GRUCell(nn.modules.rnn.GRUCell, FateTorch): - - def __init__(self, input_size, hidden_size, bias=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + input_size, + hidden_size, + bias=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['bias'] = bias self.param_dict['device'] = device @@ -81,19 +119,26 @@ def __init__(self, input_size, hidden_size, bias=True, device=None, dtype=None, self.param_dict['hidden_size'] = hidden_size self.param_dict.update(kwargs) nn.modules.rnn.GRUCell.__init__(self, **self.param_dict) - - + + class LSTM(nn.modules.rnn.LSTM, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.rnn.LSTM.__init__(self, **self.param_dict) - - + + class LSTMCell(nn.modules.rnn.LSTMCell, FateTorch): - - def __init__(self, input_size, hidden_size, bias=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + input_size, + hidden_size, + bias=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['bias'] = bias self.param_dict['device'] = device @@ -102,19 +147,32 @@ def __init__(self, input_size, hidden_size, bias=True, device=None, dtype=None, self.param_dict['hidden_size'] = hidden_size self.param_dict.update(kwargs) nn.modules.rnn.LSTMCell.__init__(self, **self.param_dict) - - + + class RNN(nn.modules.rnn.RNN, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.rnn.RNN.__init__(self, **self.param_dict) - - + + class RNNBase(nn.modules.rnn.RNNBase, FateTorch): - - def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None, **kwargs): + + def __init__( + self, + mode, + input_size, + hidden_size, + num_layers=1, + 
bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + proj_size=0, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['num_layers'] = num_layers self.param_dict['bias'] = bias @@ -129,11 +187,19 @@ def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True, batch self.param_dict['hidden_size'] = hidden_size self.param_dict.update(kwargs) nn.modules.rnn.RNNBase.__init__(self, **self.param_dict) - - + + class RNNCell(nn.modules.rnn.RNNCell, FateTorch): - - def __init__(self, input_size, hidden_size, bias=True, nonlinearity='tanh', device=None, dtype=None, **kwargs): + + def __init__( + self, + input_size, + hidden_size, + bias=True, + nonlinearity='tanh', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['bias'] = bias self.param_dict['nonlinearity'] = nonlinearity @@ -143,11 +209,19 @@ def __init__(self, input_size, hidden_size, bias=True, nonlinearity='tanh', devi self.param_dict['hidden_size'] = hidden_size self.param_dict.update(kwargs) nn.modules.rnn.RNNCell.__init__(self, **self.param_dict) - - + + class RNNCellBase(nn.modules.rnn.RNNCellBase, FateTorch): - - def __init__(self, input_size, hidden_size, bias, num_chunks, device=None, dtype=None, **kwargs): + + def __init__( + self, + input_size, + hidden_size, + bias, + num_chunks, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['device'] = device self.param_dict['dtype'] = dtype @@ -157,11 +231,23 @@ def __init__(self, input_size, hidden_size, bias, num_chunks, device=None, dtype self.param_dict['num_chunks'] = num_chunks self.param_dict.update(kwargs) nn.modules.rnn.RNNCellBase.__init__(self, **self.param_dict) - - + + class Embedding(nn.modules.sparse.Embedding, FateTorch): - - def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['padding_idx'] = padding_idx self.param_dict['max_norm'] = max_norm @@ -175,11 +261,25 @@ def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=Non self.param_dict['embedding_dim'] = embedding_dim self.param_dict.update(kwargs) nn.modules.sparse.Embedding.__init__(self, **self.param_dict) - - + + class EmbeddingBag(nn.modules.sparse.EmbeddingBag, FateTorch): - - def __init__(self, num_embeddings, embedding_dim, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, mode='mean', sparse=False, _weight=None, include_last_offset=False, padding_idx=None, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_embeddings, + embedding_dim, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + mode='mean', + sparse=False, + _weight=None, + include_last_offset=False, + padding_idx=None, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['max_norm'] = max_norm self.param_dict['norm_type'] = norm_type @@ -195,146 +295,154 @@ def __init__(self, num_embeddings, embedding_dim, max_norm=None, norm_type=2.0, self.param_dict['embedding_dim'] = embedding_dim self.param_dict.update(kwargs) nn.modules.sparse.EmbeddingBag.__init__(self, **self.param_dict) - - + + class AlphaDropout(nn.modules.dropout.AlphaDropout, FateTorch): - + def 
__init__(self, p=0.5, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['p'] = p self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.dropout.AlphaDropout.__init__(self, **self.param_dict) - - + + class Dropout(nn.modules.dropout.Dropout, FateTorch): - + def __init__(self, p=0.5, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['p'] = p self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.dropout.Dropout.__init__(self, **self.param_dict) - - + + class Dropout1d(nn.modules.dropout.Dropout1d, FateTorch): - + def __init__(self, p=0.5, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['p'] = p self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.dropout.Dropout1d.__init__(self, **self.param_dict) - - + + class Dropout2d(nn.modules.dropout.Dropout2d, FateTorch): - + def __init__(self, p=0.5, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['p'] = p self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.dropout.Dropout2d.__init__(self, **self.param_dict) - - + + class Dropout3d(nn.modules.dropout.Dropout3d, FateTorch): - + def __init__(self, p=0.5, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['p'] = p self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.dropout.Dropout3d.__init__(self, **self.param_dict) - - + + class FeatureAlphaDropout(nn.modules.dropout.FeatureAlphaDropout, FateTorch): - + def __init__(self, p=0.5, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['p'] = p self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) - nn.modules.dropout.FeatureAlphaDropout.__init__(self, **self.param_dict) - - + nn.modules.dropout.FeatureAlphaDropout.__init__( + self, **self.param_dict) + + class _DropoutNd(nn.modules.dropout._DropoutNd, FateTorch): - + def __init__(self, p=0.5, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['p'] = p self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.dropout._DropoutNd.__init__(self, **self.param_dict) - - + + class CELU(nn.modules.activation.CELU, FateTorch): - + def __init__(self, alpha=1.0, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['alpha'] = alpha self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.CELU.__init__(self, **self.param_dict) - - + + class ELU(nn.modules.activation.ELU, FateTorch): - + def __init__(self, alpha=1.0, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['alpha'] = alpha self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.ELU.__init__(self, **self.param_dict) - - + + class GELU(nn.modules.activation.GELU, FateTorch): - + def __init__(self, approximate='none', **kwargs): FateTorch.__init__(self) self.param_dict['approximate'] = approximate self.param_dict.update(kwargs) nn.modules.activation.GELU.__init__(self, **self.param_dict) - - + + class GLU(nn.modules.activation.GLU, FateTorch): - + def __init__(self, dim=-1, **kwargs): FateTorch.__init__(self) self.param_dict['dim'] = dim self.param_dict.update(kwargs) nn.modules.activation.GLU.__init__(self, **self.param_dict) - - + + class Hardshrink(nn.modules.activation.Hardshrink, FateTorch): - + def __init__(self, lambd=0.5, **kwargs): FateTorch.__init__(self) self.param_dict['lambd'] = lambd self.param_dict.update(kwargs) nn.modules.activation.Hardshrink.__init__(self, 
**self.param_dict) - - + + class Hardsigmoid(nn.modules.activation.Hardsigmoid, FateTorch): - + def __init__(self, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.Hardsigmoid.__init__(self, **self.param_dict) - - + + class Hardswish(nn.modules.activation.Hardswish, FateTorch): - + def __init__(self, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.Hardswish.__init__(self, **self.param_dict) - - + + class Hardtanh(nn.modules.activation.Hardtanh, FateTorch): - - def __init__(self, min_val=-1.0, max_val=1.0, inplace=False, min_value=None, max_value=None, **kwargs): + + def __init__( + self, + min_val=-1.0, + max_val=1.0, + inplace=False, + min_value=None, + max_value=None, + **kwargs): FateTorch.__init__(self) self.param_dict['min_val'] = min_val self.param_dict['max_val'] = max_val @@ -343,47 +451,60 @@ def __init__(self, min_val=-1.0, max_val=1.0, inplace=False, min_value=None, max self.param_dict['max_value'] = max_value self.param_dict.update(kwargs) nn.modules.activation.Hardtanh.__init__(self, **self.param_dict) - - + + class LeakyReLU(nn.modules.activation.LeakyReLU, FateTorch): - + def __init__(self, negative_slope=0.01, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['negative_slope'] = negative_slope self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.LeakyReLU.__init__(self, **self.param_dict) - - + + class LogSigmoid(nn.modules.activation.LogSigmoid, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.activation.LogSigmoid.__init__(self, **self.param_dict) - - + + class LogSoftmax(nn.modules.activation.LogSoftmax, FateTorch): - + def __init__(self, dim=None, **kwargs): FateTorch.__init__(self) self.param_dict['dim'] = dim self.param_dict.update(kwargs) nn.modules.activation.LogSoftmax.__init__(self, **self.param_dict) - - + + class Mish(nn.modules.activation.Mish, FateTorch): - + def __init__(self, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.Mish.__init__(self, **self.param_dict) - - + + class MultiheadAttention(nn.modules.activation.MultiheadAttention, FateTorch): - - def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None, **kwargs): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['dropout'] = dropout self.param_dict['bias'] = bias @@ -397,12 +518,19 @@ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=Fal self.param_dict['embed_dim'] = embed_dim self.param_dict['num_heads'] = num_heads self.param_dict.update(kwargs) - nn.modules.activation.MultiheadAttention.__init__(self, **self.param_dict) - - + nn.modules.activation.MultiheadAttention.__init__( + self, **self.param_dict) + + class PReLU(nn.modules.activation.PReLU, FateTorch): - - def __init__(self, num_parameters=1, init=0.25, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_parameters=1, + init=0.25, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) 
self.param_dict['num_parameters'] = num_parameters self.param_dict['init'] = init @@ -410,134 +538,139 @@ def __init__(self, num_parameters=1, init=0.25, device=None, dtype=None, **kwarg self.param_dict['dtype'] = dtype self.param_dict.update(kwargs) nn.modules.activation.PReLU.__init__(self, **self.param_dict) - - + + class RReLU(nn.modules.activation.RReLU, FateTorch): - - def __init__(self, lower=0.125, upper=0.3333333333333333, inplace=False, **kwargs): + + def __init__( + self, + lower=0.125, + upper=0.3333333333333333, + inplace=False, + **kwargs): FateTorch.__init__(self) self.param_dict['lower'] = lower self.param_dict['upper'] = upper self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.RReLU.__init__(self, **self.param_dict) - - + + class ReLU(nn.modules.activation.ReLU, FateTorch): - + def __init__(self, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.ReLU.__init__(self, **self.param_dict) - - + + class ReLU6(nn.modules.activation.ReLU6, FateTorch): - + def __init__(self, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.ReLU6.__init__(self, **self.param_dict) - - + + class SELU(nn.modules.activation.SELU, FateTorch): - + def __init__(self, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.SELU.__init__(self, **self.param_dict) - - + + class SiLU(nn.modules.activation.SiLU, FateTorch): - + def __init__(self, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['inplace'] = inplace self.param_dict.update(kwargs) nn.modules.activation.SiLU.__init__(self, **self.param_dict) - - + + class Sigmoid(nn.modules.activation.Sigmoid, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.activation.Sigmoid.__init__(self, **self.param_dict) - - + + class Softmax(nn.modules.activation.Softmax, FateTorch): - + def __init__(self, dim=None, **kwargs): FateTorch.__init__(self) self.param_dict['dim'] = dim self.param_dict.update(kwargs) nn.modules.activation.Softmax.__init__(self, **self.param_dict) - - + + class Softmax2d(nn.modules.activation.Softmax2d, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.activation.Softmax2d.__init__(self, **self.param_dict) - - + + class Softmin(nn.modules.activation.Softmin, FateTorch): - + def __init__(self, dim=None, **kwargs): FateTorch.__init__(self) self.param_dict['dim'] = dim self.param_dict.update(kwargs) nn.modules.activation.Softmin.__init__(self, **self.param_dict) - - + + class Softplus(nn.modules.activation.Softplus, FateTorch): - + def __init__(self, beta=1, threshold=20, **kwargs): FateTorch.__init__(self) self.param_dict['beta'] = beta self.param_dict['threshold'] = threshold self.param_dict.update(kwargs) nn.modules.activation.Softplus.__init__(self, **self.param_dict) - - + + class Softshrink(nn.modules.activation.Softshrink, FateTorch): - + def __init__(self, lambd=0.5, **kwargs): FateTorch.__init__(self) self.param_dict['lambd'] = lambd self.param_dict.update(kwargs) nn.modules.activation.Softshrink.__init__(self, **self.param_dict) - - + + class Softsign(nn.modules.activation.Softsign, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) 
nn.modules.activation.Softsign.__init__(self, **self.param_dict) - - + + class Tanh(nn.modules.activation.Tanh, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.activation.Tanh.__init__(self, **self.param_dict) - - + + class Tanhshrink(nn.modules.activation.Tanhshrink, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.activation.Tanhshrink.__init__(self, **self.param_dict) - - + + class Threshold(nn.modules.activation.Threshold, FateTorch): - + def __init__(self, threshold, value, inplace=False, **kwargs): FateTorch.__init__(self) self.param_dict['inplace'] = inplace @@ -545,11 +678,24 @@ def __init__(self, threshold, value, inplace=False, **kwargs): self.param_dict['value'] = value self.param_dict.update(kwargs) nn.modules.activation.Threshold.__init__(self, **self.param_dict) - - + + class Conv1d(nn.modules.conv.Conv1d, FateTorch): - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -564,11 +710,24 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.Conv1d.__init__(self, **self.param_dict) - - + + class Conv2d(nn.modules.conv.Conv2d, FateTorch): - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -583,11 +742,24 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.Conv2d.__init__(self, **self.param_dict) - - + + class Conv3d(nn.modules.conv.Conv3d, FateTorch): - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -602,11 +774,25 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.Conv3d.__init__(self, **self.param_dict) - - + + class ConvTranspose1d(nn.modules.conv.ConvTranspose1d, FateTorch): - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, 
+ stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -622,11 +808,25 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.ConvTranspose1d.__init__(self, **self.param_dict) - - + + class ConvTranspose2d(nn.modules.conv.ConvTranspose2d, FateTorch): - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -642,11 +842,25 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.ConvTranspose2d.__init__(self, **self.param_dict) - - + + class ConvTranspose3d(nn.modules.conv.ConvTranspose3d, FateTorch): - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -662,11 +876,23 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.ConvTranspose3d.__init__(self, **self.param_dict) - - + + class LazyConv1d(nn.modules.conv.LazyConv1d, FateTorch): - - def __init__(self, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -680,11 +906,23 @@ def __init__(self, out_channels, kernel_size, stride=1, padding=0, dilation=1, g self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.LazyConv1d.__init__(self, **self.param_dict) - - + + class LazyConv2d(nn.modules.conv.LazyConv2d, FateTorch): - - def __init__(self, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -698,11 +936,23 @@ def __init__(self, out_channels, kernel_size, stride=1, padding=0, dilation=1, g self.param_dict['kernel_size'] = 
kernel_size self.param_dict.update(kwargs) nn.modules.conv.LazyConv2d.__init__(self, **self.param_dict) - - + + class LazyConv3d(nn.modules.conv.LazyConv3d, FateTorch): - - def __init__(self, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -716,11 +966,24 @@ def __init__(self, out_channels, kernel_size, stride=1, padding=0, dilation=1, g self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.LazyConv3d.__init__(self, **self.param_dict) - - + + class LazyConvTranspose1d(nn.modules.conv.LazyConvTranspose1d, FateTorch): - - def __init__(self, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -735,11 +998,24 @@ def __init__(self, out_channels, kernel_size, stride=1, padding=0, output_paddin self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.LazyConvTranspose1d.__init__(self, **self.param_dict) - - + + class LazyConvTranspose2d(nn.modules.conv.LazyConvTranspose2d, FateTorch): - - def __init__(self, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -754,11 +1030,24 @@ def __init__(self, out_channels, kernel_size, stride=1, padding=0, output_paddin self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.LazyConvTranspose2d.__init__(self, **self.param_dict) - - + + class LazyConvTranspose3d(nn.modules.conv.LazyConvTranspose3d, FateTorch): - - def __init__(self, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None, **kwargs): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -773,11 +1062,26 @@ def __init__(self, out_channels, kernel_size, stride=1, padding=0, output_paddin self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.conv.LazyConvTranspose3d.__init__(self, **self.param_dict) - - + + class _ConvNd(nn.modules.conv._ConvNd, FateTorch): - - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode, device=None, dtype=None, **kwargs): + + def 
__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['device'] = device self.param_dict['dtype'] = dtype @@ -794,19 +1098,34 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dila self.param_dict['padding_mode'] = padding_mode self.param_dict.update(kwargs) nn.modules.conv._ConvNd.__init__(self, **self.param_dict) - - + + class _ConvTransposeMixin(nn.modules.conv._ConvTransposeMixin, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.conv._ConvTransposeMixin.__init__(self, **self.param_dict) - - + + class _ConvTransposeNd(nn.modules.conv._ConvTransposeNd, FateTorch): - - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode, device=None, dtype=None, **kwargs): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['device'] = device self.param_dict['dtype'] = dtype @@ -823,19 +1142,34 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dila self.param_dict['padding_mode'] = padding_mode self.param_dict.update(kwargs) nn.modules.conv._ConvTransposeNd.__init__(self, **self.param_dict) - - + + class _LazyConvXdMixin(nn.modules.conv._LazyConvXdMixin, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.conv._LazyConvXdMixin.__init__(self, **self.param_dict) - - + + class Transformer(nn.modules.transformer.Transformer, FateTorch): - - def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, custom_encoder=None, custom_decoder=None, layer_norm_eps=1e-05, batch_first=False, norm_first=False, device=None, dtype=None, **kwargs): + + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + custom_encoder=None, + custom_decoder=None, + layer_norm_eps=1e-05, + batch_first=False, + norm_first=False, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['d_model'] = d_model self.param_dict['nhead'] = nhead @@ -852,22 +1186,36 @@ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layer self.param_dict['dtype'] = dtype self.param_dict.update(kwargs) nn.modules.transformer.Transformer.__init__(self, **self.param_dict) - - + + class TransformerDecoder(nn.modules.transformer.TransformerDecoder, FateTorch): - + def __init__(self, decoder_layer, num_layers, norm=None, **kwargs): FateTorch.__init__(self) self.param_dict['norm'] = norm self.param_dict['decoder_layer'] = decoder_layer self.param_dict['num_layers'] = num_layers self.param_dict.update(kwargs) - nn.modules.transformer.TransformerDecoder.__init__(self, **self.param_dict) - - -class TransformerDecoderLayer(nn.modules.transformer.TransformerDecoderLayer, FateTorch): - - def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-05, batch_first=False, norm_first=False, device=None, dtype=None, **kwargs): + nn.modules.transformer.TransformerDecoder.__init__( + self, **self.param_dict) + + 
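Every wrapper in this file follows the same mechanical pattern, which is why the pep8 reflow above is so repetitive: capture the constructor arguments in param_dict, then build the underlying torch.nn module from that dict so that to_dict() (defined in fate_torch/base.py earlier in this patch) can serialize the layer configuration. A minimal standalone sketch of the pattern, assuming PyTorch is installed; RecordedLinear is an illustrative stand-in, not one of the FATE classes:

import json

import torch.nn as nn


class RecordedLinear(nn.Linear):
    # capture kwargs first (plain attributes may be assigned before
    # Module.__init__ runs), then build the real torch layer from them
    def __init__(self, in_features, out_features, bias=True, **kwargs):
        self.param_dict = {
            'in_features': in_features,
            'out_features': out_features,
            'bias': bias,
        }
        self.param_dict.update(kwargs)
        nn.Linear.__init__(self, **self.param_dict)

    def to_dict(self):
        # enough information to rebuild an identical layer later
        return {
            'module_name': 'torch.nn',
            'item_name': type(self).__name__,
            'kwargs': self.param_dict,
        }


layer = RecordedLinear(16, 4)
print(json.dumps(layer.to_dict(), indent=2))

For the FATE wrappers, item_name matches the torch class name, so rebuilding a layer reduces to getattr(torch.nn, item_name)(**kwargs), which appears to be what the load_seq path in base.py relies on.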
+class TransformerDecoderLayer( + nn.modules.transformer.TransformerDecoderLayer, + FateTorch): + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + layer_norm_eps=1e-05, + batch_first=False, + norm_first=False, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['dim_feedforward'] = dim_feedforward self.param_dict['dropout'] = dropout @@ -879,12 +1227,20 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm self.param_dict['d_model'] = d_model self.param_dict['nhead'] = nhead self.param_dict.update(kwargs) - nn.modules.transformer.TransformerDecoderLayer.__init__(self, **self.param_dict) - - + nn.modules.transformer.TransformerDecoderLayer.__init__( + self, **self.param_dict) + + class TransformerEncoder(nn.modules.transformer.TransformerEncoder, FateTorch): - - def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True, **kwargs): + + def __init__( + self, + encoder_layer, + num_layers, + norm=None, + enable_nested_tensor=True, + mask_check=True, + **kwargs): FateTorch.__init__(self) self.param_dict['norm'] = norm self.param_dict['enable_nested_tensor'] = enable_nested_tensor @@ -892,12 +1248,26 @@ def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=Tr self.param_dict['encoder_layer'] = encoder_layer self.param_dict['num_layers'] = num_layers self.param_dict.update(kwargs) - nn.modules.transformer.TransformerEncoder.__init__(self, **self.param_dict) - - -class TransformerEncoderLayer(nn.modules.transformer.TransformerEncoderLayer, FateTorch): - - def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-05, batch_first=False, norm_first=False, device=None, dtype=None, **kwargs): + nn.modules.transformer.TransformerEncoder.__init__( + self, **self.param_dict) + + +class TransformerEncoderLayer( + nn.modules.transformer.TransformerEncoderLayer, + FateTorch): + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + layer_norm_eps=1e-05, + batch_first=False, + norm_first=False, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['dim_feedforward'] = dim_feedforward self.param_dict['dropout'] = dropout @@ -909,69 +1279,77 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm self.param_dict['d_model'] = d_model self.param_dict['nhead'] = nhead self.param_dict.update(kwargs) - nn.modules.transformer.TransformerEncoderLayer.__init__(self, **self.param_dict) - - + nn.modules.transformer.TransformerEncoderLayer.__init__( + self, **self.param_dict) + + class AdaptiveAvgPool1d(nn.modules.pooling.AdaptiveAvgPool1d, FateTorch): - + def __init__(self, output_size, **kwargs): FateTorch.__init__(self) self.param_dict['output_size'] = output_size self.param_dict.update(kwargs) nn.modules.pooling.AdaptiveAvgPool1d.__init__(self, **self.param_dict) - - + + class AdaptiveAvgPool2d(nn.modules.pooling.AdaptiveAvgPool2d, FateTorch): - + def __init__(self, output_size, **kwargs): FateTorch.__init__(self) self.param_dict['output_size'] = output_size self.param_dict.update(kwargs) nn.modules.pooling.AdaptiveAvgPool2d.__init__(self, **self.param_dict) - - + + class AdaptiveAvgPool3d(nn.modules.pooling.AdaptiveAvgPool3d, FateTorch): - + def __init__(self, output_size, **kwargs): FateTorch.__init__(self) self.param_dict['output_size'] = output_size self.param_dict.update(kwargs) nn.modules.pooling.AdaptiveAvgPool3d.__init__(self, 
**self.param_dict) - - + + class AdaptiveMaxPool1d(nn.modules.pooling.AdaptiveMaxPool1d, FateTorch): - + def __init__(self, output_size, return_indices=False, **kwargs): FateTorch.__init__(self) self.param_dict['return_indices'] = return_indices self.param_dict['output_size'] = output_size self.param_dict.update(kwargs) nn.modules.pooling.AdaptiveMaxPool1d.__init__(self, **self.param_dict) - - + + class AdaptiveMaxPool2d(nn.modules.pooling.AdaptiveMaxPool2d, FateTorch): - + def __init__(self, output_size, return_indices=False, **kwargs): FateTorch.__init__(self) self.param_dict['return_indices'] = return_indices self.param_dict['output_size'] = output_size self.param_dict.update(kwargs) nn.modules.pooling.AdaptiveMaxPool2d.__init__(self, **self.param_dict) - - + + class AdaptiveMaxPool3d(nn.modules.pooling.AdaptiveMaxPool3d, FateTorch): - + def __init__(self, output_size, return_indices=False, **kwargs): FateTorch.__init__(self) self.param_dict['return_indices'] = return_indices self.param_dict['output_size'] = output_size self.param_dict.update(kwargs) nn.modules.pooling.AdaptiveMaxPool3d.__init__(self, **self.param_dict) - - + + class AvgPool1d(nn.modules.pooling.AvgPool1d, FateTorch): - - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, **kwargs): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -980,11 +1358,19 @@ def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_i self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.AvgPool1d.__init__(self, **self.param_dict) - - + + class AvgPool2d(nn.modules.pooling.AvgPool2d, FateTorch): - - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None, **kwargs): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -994,11 +1380,19 @@ def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_i self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.AvgPool2d.__init__(self, **self.param_dict) - - + + class AvgPool3d(nn.modules.pooling.AvgPool3d, FateTorch): - - def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None, **kwargs): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -1008,11 +1402,18 @@ def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_i self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.AvgPool3d.__init__(self, **self.param_dict) - - + + class FractionalMaxPool2d(nn.modules.pooling.FractionalMaxPool2d, FateTorch): - - def __init__(self, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None, **kwargs): + + def __init__( + self, + kernel_size, + output_size=None, + output_ratio=None, + return_indices=False, + _random_samples=None, + **kwargs): FateTorch.__init__(self) 
self.param_dict['output_size'] = output_size self.param_dict['output_ratio'] = output_ratio @@ -1020,12 +1421,20 @@ def __init__(self, kernel_size, output_size=None, output_ratio=None, return_indi self.param_dict['_random_samples'] = _random_samples self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) - nn.modules.pooling.FractionalMaxPool2d.__init__(self, **self.param_dict) - - + nn.modules.pooling.FractionalMaxPool2d.__init__( + self, **self.param_dict) + + class FractionalMaxPool3d(nn.modules.pooling.FractionalMaxPool3d, FateTorch): - - def __init__(self, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None, **kwargs): + + def __init__( + self, + kernel_size, + output_size=None, + output_ratio=None, + return_indices=False, + _random_samples=None, + **kwargs): FateTorch.__init__(self) self.param_dict['output_size'] = output_size self.param_dict['output_ratio'] = output_ratio @@ -1033,12 +1442,19 @@ def __init__(self, kernel_size, output_size=None, output_ratio=None, return_indi self.param_dict['_random_samples'] = _random_samples self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) - nn.modules.pooling.FractionalMaxPool3d.__init__(self, **self.param_dict) - - + nn.modules.pooling.FractionalMaxPool3d.__init__( + self, **self.param_dict) + + class LPPool1d(nn.modules.pooling.LPPool1d, FateTorch): - - def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False, **kwargs): + + def __init__( + self, + norm_type, + kernel_size, + stride=None, + ceil_mode=False, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['ceil_mode'] = ceil_mode @@ -1046,11 +1462,17 @@ def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False, **kwarg self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.LPPool1d.__init__(self, **self.param_dict) - - + + class LPPool2d(nn.modules.pooling.LPPool2d, FateTorch): - - def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False, **kwargs): + + def __init__( + self, + norm_type, + kernel_size, + stride=None, + ceil_mode=False, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['ceil_mode'] = ceil_mode @@ -1058,11 +1480,19 @@ def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False, **kwarg self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.LPPool2d.__init__(self, **self.param_dict) - - + + class MaxPool1d(nn.modules.pooling.MaxPool1d, FateTorch): - - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -1072,11 +1502,19 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indic self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.MaxPool1d.__init__(self, **self.param_dict) - - + + class MaxPool2d(nn.modules.pooling.MaxPool2d, FateTorch): - - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + **kwargs): FateTorch.__init__(self) 
self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -1086,11 +1524,19 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indic self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.MaxPool2d.__init__(self, **self.param_dict) - - + + class MaxPool3d(nn.modules.pooling.MaxPool3d, FateTorch): - - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -1100,10 +1546,10 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indic self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.MaxPool3d.__init__(self, **self.param_dict) - - + + class MaxUnpool1d(nn.modules.pooling.MaxUnpool1d, FateTorch): - + def __init__(self, kernel_size, stride=None, padding=0, **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride @@ -1111,10 +1557,10 @@ def __init__(self, kernel_size, stride=None, padding=0, **kwargs): self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.MaxUnpool1d.__init__(self, **self.param_dict) - - + + class MaxUnpool2d(nn.modules.pooling.MaxUnpool2d, FateTorch): - + def __init__(self, kernel_size, stride=None, padding=0, **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride @@ -1122,10 +1568,10 @@ def __init__(self, kernel_size, stride=None, padding=0, **kwargs): self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.MaxUnpool2d.__init__(self, **self.param_dict) - - + + class MaxUnpool3d(nn.modules.pooling.MaxUnpool3d, FateTorch): - + def __init__(self, kernel_size, stride=None, padding=0, **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride @@ -1133,38 +1579,44 @@ def __init__(self, kernel_size, stride=None, padding=0, **kwargs): self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling.MaxUnpool3d.__init__(self, **self.param_dict) - - + + class _AdaptiveAvgPoolNd(nn.modules.pooling._AdaptiveAvgPoolNd, FateTorch): - + def __init__(self, output_size, **kwargs): FateTorch.__init__(self) self.param_dict['output_size'] = output_size self.param_dict.update(kwargs) nn.modules.pooling._AdaptiveAvgPoolNd.__init__(self, **self.param_dict) - - + + class _AdaptiveMaxPoolNd(nn.modules.pooling._AdaptiveMaxPoolNd, FateTorch): - + def __init__(self, output_size, return_indices=False, **kwargs): FateTorch.__init__(self) self.param_dict['return_indices'] = return_indices self.param_dict['output_size'] = output_size self.param_dict.update(kwargs) nn.modules.pooling._AdaptiveMaxPoolNd.__init__(self, **self.param_dict) - - + + class _AvgPoolNd(nn.modules.pooling._AvgPoolNd, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.pooling._AvgPoolNd.__init__(self, **self.param_dict) - - + + class _LPPoolNd(nn.modules.pooling._LPPoolNd, FateTorch): - - def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False, **kwargs): + + def __init__( + self, + norm_type, + kernel_size, + stride=None, + ceil_mode=False, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['ceil_mode'] = ceil_mode @@ -1172,11 
+1624,19 @@ def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False, **kwarg self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling._LPPoolNd.__init__(self, **self.param_dict) - - + + class _MaxPoolNd(nn.modules.pooling._MaxPoolNd, FateTorch): - - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + **kwargs): FateTorch.__init__(self) self.param_dict['stride'] = stride self.param_dict['padding'] = padding @@ -1186,19 +1646,28 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indic self.param_dict['kernel_size'] = kernel_size self.param_dict.update(kwargs) nn.modules.pooling._MaxPoolNd.__init__(self, **self.param_dict) - - + + class _MaxUnpoolNd(nn.modules.pooling._MaxUnpoolNd, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.pooling._MaxUnpoolNd.__init__(self, **self.param_dict) - - + + class BatchNorm1d(nn.modules.batchnorm.BatchNorm1d, FateTorch): - - def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1209,11 +1678,20 @@ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_run self.param_dict['num_features'] = num_features self.param_dict.update(kwargs) nn.modules.batchnorm.BatchNorm1d.__init__(self, **self.param_dict) - - + + class BatchNorm2d(nn.modules.batchnorm.BatchNorm2d, FateTorch): - - def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1224,11 +1702,20 @@ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_run self.param_dict['num_features'] = num_features self.param_dict.update(kwargs) nn.modules.batchnorm.BatchNorm2d.__init__(self, **self.param_dict) - - + + class BatchNorm3d(nn.modules.batchnorm.BatchNorm3d, FateTorch): - - def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1239,11 +1726,19 @@ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_run self.param_dict['num_features'] = num_features self.param_dict.update(kwargs) nn.modules.batchnorm.BatchNorm3d.__init__(self, **self.param_dict) - - + + class LazyBatchNorm1d(nn.modules.batchnorm.LazyBatchNorm1d, FateTorch): - - def __init__(self, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + eps=1e-05, + momentum=0.1, + affine=True, + 
track_running_stats=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1253,11 +1748,19 @@ def __init__(self, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru self.param_dict['dtype'] = dtype self.param_dict.update(kwargs) nn.modules.batchnorm.LazyBatchNorm1d.__init__(self, **self.param_dict) - - + + class LazyBatchNorm2d(nn.modules.batchnorm.LazyBatchNorm2d, FateTorch): - - def __init__(self, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1267,11 +1770,19 @@ def __init__(self, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru self.param_dict['dtype'] = dtype self.param_dict.update(kwargs) nn.modules.batchnorm.LazyBatchNorm2d.__init__(self, **self.param_dict) - - + + class LazyBatchNorm3d(nn.modules.batchnorm.LazyBatchNorm3d, FateTorch): - - def __init__(self, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1281,11 +1792,21 @@ def __init__(self, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru self.param_dict['dtype'] = dtype self.param_dict.update(kwargs) nn.modules.batchnorm.LazyBatchNorm3d.__init__(self, **self.param_dict) - - + + class SyncBatchNorm(nn.modules.batchnorm.SyncBatchNorm, FateTorch): - - def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + process_group=None, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1297,11 +1818,20 @@ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_run self.param_dict['num_features'] = num_features self.param_dict.update(kwargs) nn.modules.batchnorm.SyncBatchNorm.__init__(self, **self.param_dict) - - + + class _BatchNorm(nn.modules.batchnorm._BatchNorm, FateTorch): - - def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1312,11 +1842,19 @@ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_run self.param_dict['num_features'] = num_features self.param_dict.update(kwargs) nn.modules.batchnorm._BatchNorm.__init__(self, **self.param_dict) - - + + class _LazyNormBase(nn.modules.batchnorm._LazyNormBase, FateTorch): - - def __init__(self, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): 
FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1326,11 +1864,20 @@ def __init__(self, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru self.param_dict['dtype'] = dtype self.param_dict.update(kwargs) nn.modules.batchnorm._LazyNormBase.__init__(self, **self.param_dict) - - + + class _NormBase(nn.modules.batchnorm._NormBase, FateTorch): - - def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, **kwargs): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): FateTorch.__init__(self) self.param_dict['eps'] = eps self.param_dict['momentum'] = momentum @@ -1341,129 +1888,135 @@ def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_run self.param_dict['num_features'] = num_features self.param_dict.update(kwargs) nn.modules.batchnorm._NormBase.__init__(self, **self.param_dict) - - + + class ConstantPad1d(nn.modules.padding.ConstantPad1d, FateTorch): - + def __init__(self, padding, value, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict['value'] = value self.param_dict.update(kwargs) nn.modules.padding.ConstantPad1d.__init__(self, **self.param_dict) - - + + class ConstantPad2d(nn.modules.padding.ConstantPad2d, FateTorch): - + def __init__(self, padding, value, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict['value'] = value self.param_dict.update(kwargs) nn.modules.padding.ConstantPad2d.__init__(self, **self.param_dict) - - + + class ConstantPad3d(nn.modules.padding.ConstantPad3d, FateTorch): - + def __init__(self, padding, value, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict['value'] = value self.param_dict.update(kwargs) nn.modules.padding.ConstantPad3d.__init__(self, **self.param_dict) - - + + class ReflectionPad1d(nn.modules.padding.ReflectionPad1d, FateTorch): - + def __init__(self, padding, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict.update(kwargs) nn.modules.padding.ReflectionPad1d.__init__(self, **self.param_dict) - - + + class ReflectionPad2d(nn.modules.padding.ReflectionPad2d, FateTorch): - + def __init__(self, padding, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict.update(kwargs) nn.modules.padding.ReflectionPad2d.__init__(self, **self.param_dict) - - + + class ReflectionPad3d(nn.modules.padding.ReflectionPad3d, FateTorch): - + def __init__(self, padding, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict.update(kwargs) nn.modules.padding.ReflectionPad3d.__init__(self, **self.param_dict) - - + + class ReplicationPad1d(nn.modules.padding.ReplicationPad1d, FateTorch): - + def __init__(self, padding, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict.update(kwargs) nn.modules.padding.ReplicationPad1d.__init__(self, **self.param_dict) - - + + class ReplicationPad2d(nn.modules.padding.ReplicationPad2d, FateTorch): - + def __init__(self, padding, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict.update(kwargs) nn.modules.padding.ReplicationPad2d.__init__(self, **self.param_dict) - - + + class ReplicationPad3d(nn.modules.padding.ReplicationPad3d, FateTorch): - + def __init__(self, padding, **kwargs): FateTorch.__init__(self) 
self.param_dict['padding'] = padding self.param_dict.update(kwargs) nn.modules.padding.ReplicationPad3d.__init__(self, **self.param_dict) - - + + class ZeroPad2d(nn.modules.padding.ZeroPad2d, FateTorch): - + def __init__(self, padding, **kwargs): FateTorch.__init__(self) self.param_dict['padding'] = padding self.param_dict.update(kwargs) nn.modules.padding.ZeroPad2d.__init__(self, **self.param_dict) - - + + class _ConstantPadNd(nn.modules.padding._ConstantPadNd, FateTorch): - + def __init__(self, value, **kwargs): FateTorch.__init__(self) self.param_dict['value'] = value self.param_dict.update(kwargs) nn.modules.padding._ConstantPadNd.__init__(self, **self.param_dict) - - + + class _ReflectionPadNd(nn.modules.padding._ReflectionPadNd, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.padding._ReflectionPadNd.__init__(self, **self.param_dict) - - + + class _ReplicationPadNd(nn.modules.padding._ReplicationPadNd, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.padding._ReplicationPadNd.__init__(self, **self.param_dict) - - + + class BCELoss(nn.modules.loss.BCELoss, FateTorch): - - def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['weight'] = weight self.param_dict['size_average'] = size_average @@ -1471,11 +2024,18 @@ def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean' self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.BCELoss.__init__(self, **self.param_dict) - - + + class BCEWithLogitsLoss(nn.modules.loss.BCEWithLogitsLoss, FateTorch): - - def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None, **kwargs): + + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + pos_weight=None, + **kwargs): FateTorch.__init__(self) self.param_dict['weight'] = weight self.param_dict['size_average'] = size_average @@ -1484,22 +2044,33 @@ def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean' self.param_dict['pos_weight'] = pos_weight self.param_dict.update(kwargs) nn.modules.loss.BCEWithLogitsLoss.__init__(self, **self.param_dict) - - + + class CTCLoss(nn.modules.loss.CTCLoss, FateTorch): - - def __init__(self, blank=0, reduction='mean', zero_infinity=False, **kwargs): + + def __init__( + self, + blank=0, + reduction='mean', + zero_infinity=False, + **kwargs): FateTorch.__init__(self) self.param_dict['blank'] = blank self.param_dict['reduction'] = reduction self.param_dict['zero_infinity'] = zero_infinity self.param_dict.update(kwargs) nn.modules.loss.CTCLoss.__init__(self, **self.param_dict) - - + + class CosineEmbeddingLoss(nn.modules.loss.CosineEmbeddingLoss, FateTorch): - - def __init__(self, margin=0.0, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + margin=0.0, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['margin'] = margin self.param_dict['size_average'] = size_average @@ -1507,11 +2078,19 @@ def __init__(self, margin=0.0, size_average=None, reduce=None, reduction='mean', self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.CosineEmbeddingLoss.__init__(self, 
**self.param_dict) - - + + class CrossEntropyLoss(nn.modules.loss.CrossEntropyLoss, FateTorch): - - def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0, **kwargs): + + def __init__( + self, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction='mean', + label_smoothing=0.0, + **kwargs): FateTorch.__init__(self) self.param_dict['weight'] = weight self.param_dict['size_average'] = size_average @@ -1521,19 +2100,25 @@ def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=Non self.param_dict['label_smoothing'] = label_smoothing self.param_dict.update(kwargs) nn.modules.loss.CrossEntropyLoss.__init__(self, **self.param_dict) - - + + class GaussianNLLLoss(nn.modules.loss.GaussianNLLLoss, FateTorch): - + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) nn.modules.loss.GaussianNLLLoss.__init__(self, **self.param_dict) - - + + class HingeEmbeddingLoss(nn.modules.loss.HingeEmbeddingLoss, FateTorch): - - def __init__(self, margin=1.0, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + margin=1.0, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['margin'] = margin self.param_dict['size_average'] = size_average @@ -1541,21 +2126,27 @@ def __init__(self, margin=1.0, size_average=None, reduce=None, reduction='mean', self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.HingeEmbeddingLoss.__init__(self, **self.param_dict) - - + + class HuberLoss(nn.modules.loss.HuberLoss, FateTorch): - + def __init__(self, reduction='mean', delta=1.0, **kwargs): FateTorch.__init__(self) self.param_dict['reduction'] = reduction self.param_dict['delta'] = delta self.param_dict.update(kwargs) nn.modules.loss.HuberLoss.__init__(self, **self.param_dict) - - + + class KLDivLoss(nn.modules.loss.KLDivLoss, FateTorch): - - def __init__(self, size_average=None, reduce=None, reduction='mean', log_target=False, **kwargs): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + log_target=False, + **kwargs): FateTorch.__init__(self) self.param_dict['size_average'] = size_average self.param_dict['reduce'] = reduce @@ -1563,33 +2154,49 @@ def __init__(self, size_average=None, reduce=None, reduction='mean', log_target= self.param_dict['log_target'] = log_target self.param_dict.update(kwargs) nn.modules.loss.KLDivLoss.__init__(self, **self.param_dict) - - + + class L1Loss(nn.modules.loss.L1Loss, FateTorch): - - def __init__(self, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['size_average'] = size_average self.param_dict['reduce'] = reduce self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.L1Loss.__init__(self, **self.param_dict) - - + + class MSELoss(nn.modules.loss.MSELoss, FateTorch): - - def __init__(self, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['size_average'] = size_average self.param_dict['reduce'] = reduce self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.MSELoss.__init__(self, **self.param_dict) - - + + class 
MarginRankingLoss(nn.modules.loss.MarginRankingLoss, FateTorch): - - def __init__(self, margin=0.0, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + margin=0.0, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['margin'] = margin self.param_dict['size_average'] = size_average @@ -1597,34 +2204,56 @@ def __init__(self, margin=0.0, size_average=None, reduce=None, reduction='mean', self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.MarginRankingLoss.__init__(self, **self.param_dict) - - + + class MultiLabelMarginLoss(nn.modules.loss.MultiLabelMarginLoss, FateTorch): - - def __init__(self, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['size_average'] = size_average self.param_dict['reduce'] = reduce self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.MultiLabelMarginLoss.__init__(self, **self.param_dict) - - -class MultiLabelSoftMarginLoss(nn.modules.loss.MultiLabelSoftMarginLoss, FateTorch): - - def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', **kwargs): + + +class MultiLabelSoftMarginLoss( + nn.modules.loss.MultiLabelSoftMarginLoss, + FateTorch): + + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['weight'] = weight self.param_dict['size_average'] = size_average self.param_dict['reduce'] = reduce self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) - nn.modules.loss.MultiLabelSoftMarginLoss.__init__(self, **self.param_dict) - - + nn.modules.loss.MultiLabelSoftMarginLoss.__init__( + self, **self.param_dict) + + class MultiMarginLoss(nn.modules.loss.MultiMarginLoss, FateTorch): - - def __init__(self, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + p=1, + margin=1.0, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['p'] = p self.param_dict['margin'] = margin @@ -1634,11 +2263,18 @@ def __init__(self, p=1, margin=1.0, weight=None, size_average=None, reduce=None, self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.MultiMarginLoss.__init__(self, **self.param_dict) - - + + class NLLLoss(nn.modules.loss.NLLLoss, FateTorch): - - def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['weight'] = weight self.param_dict['size_average'] = size_average @@ -1647,11 +2283,18 @@ def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=Non self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.NLLLoss.__init__(self, **self.param_dict) - - + + class NLLLoss2d(nn.modules.loss.NLLLoss2d, FateTorch): - - def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) 
self.param_dict['weight'] = weight self.param_dict['size_average'] = size_average @@ -1660,11 +2303,19 @@ def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=Non self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.NLLLoss2d.__init__(self, **self.param_dict) - - + + class PoissonNLLLoss(nn.modules.loss.PoissonNLLLoss, FateTorch): - - def __init__(self, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + log_input=True, + full=False, + size_average=None, + eps=1e-08, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['log_input'] = log_input self.param_dict['full'] = full @@ -1674,11 +2325,17 @@ def __init__(self, log_input=True, full=False, size_average=None, eps=1e-08, red self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.PoissonNLLLoss.__init__(self, **self.param_dict) - - + + class SmoothL1Loss(nn.modules.loss.SmoothL1Loss, FateTorch): - - def __init__(self, size_average=None, reduce=None, reduction='mean', beta=1.0, **kwargs): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + beta=1.0, + **kwargs): FateTorch.__init__(self) self.param_dict['size_average'] = size_average self.param_dict['reduce'] = reduce @@ -1686,22 +2343,36 @@ def __init__(self, size_average=None, reduce=None, reduction='mean', beta=1.0, * self.param_dict['beta'] = beta self.param_dict.update(kwargs) nn.modules.loss.SmoothL1Loss.__init__(self, **self.param_dict) - - + + class SoftMarginLoss(nn.modules.loss.SoftMarginLoss, FateTorch): - - def __init__(self, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['size_average'] = size_average self.param_dict['reduce'] = reduce self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.SoftMarginLoss.__init__(self, **self.param_dict) - - + + class TripletMarginLoss(nn.modules.loss.TripletMarginLoss, FateTorch): - - def __init__(self, margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + margin=1.0, + p=2.0, + eps=1e-06, + swap=False, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['margin'] = margin self.param_dict['p'] = p @@ -1712,30 +2383,44 @@ def __init__(self, margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss.TripletMarginLoss.__init__(self, **self.param_dict) - - -class TripletMarginWithDistanceLoss(nn.modules.loss.TripletMarginWithDistanceLoss, FateTorch): - + + +class TripletMarginWithDistanceLoss( + nn.modules.loss.TripletMarginWithDistanceLoss, + FateTorch): + def __init__(self, **kwargs): FateTorch.__init__(self) self.param_dict.update(kwargs) - nn.modules.loss.TripletMarginWithDistanceLoss.__init__(self, **self.param_dict) - - + nn.modules.loss.TripletMarginWithDistanceLoss.__init__( + self, **self.param_dict) + + class _Loss(nn.modules.loss._Loss, FateTorch): - - def __init__(self, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['size_average'] = size_average 
self.param_dict['reduce'] = reduce self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss._Loss.__init__(self, **self.param_dict) - - + + class _WeightedLoss(nn.modules.loss._WeightedLoss, FateTorch): - - def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', **kwargs): + + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): FateTorch.__init__(self) self.param_dict['weight'] = weight self.param_dict['size_average'] = size_average @@ -1743,5 +2428,3 @@ def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean' self.param_dict['reduction'] = reduction self.param_dict.update(kwargs) nn.modules.loss._WeightedLoss.__init__(self, **self.param_dict) - - \ No newline at end of file diff --git a/python/fate/components/components/nn/fate_torch/optim.py b/python/fate/components/components/nn/fate_torch/optim.py index 005cfb7651..abd59df5c5 100644 --- a/python/fate/components/components/nn/fate_torch/optim.py +++ b/python/fate/components/components/nn/fate_torch/optim.py @@ -3,8 +3,18 @@ class ASGD(optim.ASGD, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0, foreach=None, maximize=False, ): + + def __init__( + self, + params=None, + lr=0.01, + lambd=0.0001, + alpha=0.75, + t0=1000000.0, + weight_decay=0, + foreach=None, + maximize=False, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['lambd'] = lambd @@ -27,14 +37,22 @@ def __init__(self, params=None, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer ASGD without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer ASGD without initiated parameters'.format( + type(self).__name__) + - - class Adadelta(optim.Adadelta, FateTorchOptimizer): - - def __init__(self, params=None, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0, foreach=None, ): + + def __init__( + self, + params=None, + lr=1.0, + rho=0.9, + eps=1e-06, + weight_decay=0, + foreach=None, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['rho'] = rho @@ -55,14 +73,23 @@ def __init__(self, params=None, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0, fore def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer Adadelta without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer Adadelta without initiated parameters'.format( + type(self).__name__) + - - class Adagrad(optim.Adagrad, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, foreach=None, ): + + def __init__( + self, + params=None, + lr=0.01, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + foreach=None, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['lr_decay'] = lr_decay @@ -84,14 +111,24 @@ def __init__(self, params=None, lr=0.01, lr_decay=0, weight_decay=0, initial_acc def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer Adagrad without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer Adagrad without initiated parameters'.format( + type(self).__name__) + - - class Adam(optim.Adam, FateTorchOptimizer): - - def 
__init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, ): + + def __init__( + self, + params=None, + lr=0.001, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0, + amsgrad=False, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['betas'] = betas @@ -112,14 +149,24 @@ def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_ def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer Adam without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer Adam without initiated parameters'.format( + type(self).__name__) + - - class AdamW(optim.AdamW, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False, ): + + def __init__( + self, + params=None, + lr=0.001, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0.01, + amsgrad=False, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['betas'] = betas @@ -140,14 +187,24 @@ def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_ def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer AdamW without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer AdamW without initiated parameters'.format( + type(self).__name__) + - - class Adamax(optim.Adamax, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, foreach=None, ): + + def __init__( + self, + params=None, + lr=0.002, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0, + foreach=None, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['betas'] = betas @@ -168,14 +225,24 @@ def __init__(self, params=None, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_ def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer Adamax without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer Adamax without initiated parameters'.format( + type(self).__name__) + - - class LBFGS(optim.LBFGS, FateTorchOptimizer): - - def __init__(self, params=None, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn=None, ): + + def __init__( + self, + params=None, + lr=1, + max_iter=20, + max_eval=None, + tolerance_grad=1e-07, + tolerance_change=1e-09, + history_size=100, + line_search_fn=None, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['max_iter'] = max_iter @@ -198,14 +265,25 @@ def __init__(self, params=None, lr=1, max_iter=20, max_eval=None, tolerance_grad def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer LBFGS without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer LBFGS without initiated parameters'.format( + type(self).__name__) + - - class NAdam(optim.NAdam, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, momentum_decay=0.004, foreach=None, ): + + def __init__( + self, + params=None, + lr=0.002, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0, + momentum_decay=0.004, + foreach=None, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['betas'] = betas 
@@ -227,14 +305,24 @@ def __init__(self, params=None, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_ def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer NAdam without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer NAdam without initiated parameters'.format( + type(self).__name__) + - - class RAdam(optim.RAdam, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, foreach=None, ): + + def __init__( + self, + params=None, + lr=0.001, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0, + foreach=None, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['betas'] = betas @@ -255,14 +343,26 @@ def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_ def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer RAdam without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer RAdam without initiated parameters'.format( + type(self).__name__) + - - class RMSprop(optim.RMSprop, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False, foreach=None, maximize=False, differentiable=False, ): + + def __init__( + self, + params=None, + lr=0.01, + alpha=0.99, + eps=1e-08, + weight_decay=0, + momentum=0, + centered=False, + foreach=None, + maximize=False, + differentiable=False, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['alpha'] = alpha @@ -287,14 +387,17 @@ def __init__(self, params=None, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer RMSprop without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer RMSprop without initiated parameters'.format( + type(self).__name__) + - - class Rprop(optim.Rprop, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50), foreach=None, maximize=False, ): + + def __init__( + self, params=None, lr=0.01, etas=( + 0.5, 1.2), step_sizes=( + 1e-06, 50), foreach=None, maximize=False, ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['etas'] = etas @@ -315,14 +418,22 @@ def __init__(self, params=None, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50) def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer Rprop without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer Rprop without initiated parameters'.format( + type(self).__name__) + - - class SGD(optim.SGD, FateTorchOptimizer): - - def __init__(self, lr, params=None, momentum=0, dampening=0, weight_decay=0, nesterov=False, ): + + def __init__( + self, + lr, + params=None, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['momentum'] = momentum @@ -343,14 +454,23 @@ def __init__(self, lr, params=None, momentum=0, dampening=0, weight_decay=0, nes def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer SGD without initiated parameters'.format(type(self).__name__) + except BaseException: + return 'Optimizer SGD without initiated parameters'.format( + type(self).__name__) + - - 
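The optimizer hunks follow the same rewrap pattern, with one behavioral detail worth calling out: every wrapper accepts `params=None`, so an optimizer can be declared as pure configuration before any model exists, and `__repr__` falls back to a plain string while the torch optimizer is still uninitialized. A sketch of that idea (the `if params is None: return` short-circuit is an assumption for illustration; the real wrapper bodies contain more than shown):

import torch
from torch import optim


class FateTorchOptimizer(object):
    def __init__(self):
        self.param_dict = {}


class SGD(optim.SGD, FateTorchOptimizer):

    def __init__(self, lr, params=None, momentum=0):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['momentum'] = momentum
        if params is None:
            # config-only instance: torch's optimizer is not initialized yet
            return
        optim.SGD.__init__(self, params, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            # uninitialized torch optimizer: fall back to a plain string
            return 'Optimizer {} without initiated parameters'.format(
                type(self).__name__)


conf = SGD(lr=0.1)                            # declared before any model exists
print(conf)                                   # repr falls back: no param_groups yet
model = torch.nn.Linear(4, 1)
opt = SGD(lr=0.1, params=model.parameters())  # fully initialized
print(opt.param_dict)                         # {'lr': 0.1, 'momentum': 0}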
class SparseAdam(optim.SparseAdam, FateTorchOptimizer): - - def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, maximize=False, ): + + def __init__( + self, + params=None, + lr=0.001, + betas=( + 0.9, + 0.999), + eps=1e-08, + maximize=False, + ): FateTorchOptimizer.__init__(self) self.param_dict['lr'] = lr self.param_dict['betas'] = betas @@ -370,8 +490,6 @@ def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, maximiz def __repr__(self): try: return type(self).__bases__[0].__repr__(self) - except: - return 'Optimizer SparseAdam without initiated parameters'.format(type(self).__name__) - - - \ No newline at end of file + except BaseException: + return 'Optimizer SparseAdam without initiated parameters'.format( + type(self).__name__) diff --git a/python/fate/components/components/nn/loader.py b/python/fate/components/components/nn/loader.py index 110d00d48e..cbbf281525 100644 --- a/python/fate/components/components/nn/loader.py +++ b/python/fate/components/components/nn/loader.py @@ -15,6 +15,7 @@ class _Source(object): SOURCE_FILE = 'source.yaml' + def is_path(s): return os.path.exists(s) @@ -26,6 +27,7 @@ def load_source(): source = yaml.safe_load(f) return source + class AbstractLoader(ABC): @abstractmethod def __init__(self, module_name, item_name, source=None): @@ -49,7 +51,7 @@ def to_dict(self): class Loader(AbstractLoader): - + def __init__(self, module_name, item_name, source=None, **kwargs): self.item_name = item_name @@ -63,7 +65,9 @@ def __init__(self, module_name, item_name, source=None, **kwargs): if self.source in source_dict: self.source_path = source_dict[self.source] else: - raise ValueError('source name {} is not found in the source.yaml file. Please check the source name.'.format(self.source)) + raise ValueError( + 'source name {} is not found in the source.yaml file. Please check the source name.'.format( + self.source)) elif source is None: self.module_name = module_name @@ -93,15 +97,21 @@ def _load_item(self): suggestion = self._find_similar_module_names() print('suggestion is {}'.format(suggestion)) if suggestion: - raise ValueError("Module: {} not found in the import path. Do you mean {}?".format(self.module_name, suggestion)) + raise ValueError( + "Module: {} not found in the import path. 
Do you mean {}?".format( + self.module_name, suggestion)) else: - raise ValueError("Module: {} not found in the import path.".format(self.module_name)) + raise ValueError( + "Module: {} not found in the import path.".format( + self.module_name)) module = importlib.import_module(self.module_name) - + item = getattr(module, self.item_name, None) if item is None: - raise ValueError("Item: {} not found in module: {}.".format(self.item_name, self.module_name)) + raise ValueError( + "Item: {} not found in module: {}.".format( + self.item_name, self.module_name)) if self.source_path is not None: sys.path.remove(self.source_path) @@ -135,8 +145,8 @@ def from_json(json_str): @staticmethod def from_dict(data_dict): - return Loader(module_name=data_dict['module_name'], - item_name=data_dict['item_name'], + return Loader(module_name=data_dict['module_name'], + item_name=data_dict['item_name'], source=data_dict.get('source', None), **data_dict.get('kwargs', {}) ) @@ -145,19 +155,40 @@ def from_dict(data_dict): class ModelLoader(Loader): def __init__(self, module_name, item_name, source=None, **kwargs): if source is None: - module_name = f'{_Source.MODEL_ZOO}.{module_name}' # add prefix for moduele loader - super(ModelLoader, self).__init__(module_name, item_name, source, **kwargs) + # add prefix for moduele loader + module_name = f'{_Source.MODEL_ZOO}.{module_name}' + super( + ModelLoader, + self).__init__( + module_name, + item_name, + source, + **kwargs) class DatasetLoader(Loader): def __init__(self, module_name, item_name, source=None, **kwargs): if source is None: - module_name = f'{_Source.DATASET}.{module_name}' # add prefix for moduele loader - super(DatasetLoader, self).__init__(module_name, item_name, source, **kwargs) + # add prefix for moduele loader + module_name = f'{_Source.DATASET}.{module_name}' + super( + DatasetLoader, + self).__init__( + module_name, + item_name, + source, + **kwargs) class CustFuncLoader(Loader): def __init__(self, module_name, item_name, source=None, **kwargs): if source is None: - module_name = f'{_Source.CUST_FUNC}.{module_name}' # add prefix for moduele loader - super(CustFuncLoader, self).__init__(module_name, item_name, source, **kwargs) \ No newline at end of file + # add prefix for moduele loader + module_name = f'{_Source.CUST_FUNC}.{module_name}' + super( + CustFuncLoader, + self).__init__( + module_name, + item_name, + source, + **kwargs) diff --git a/python/fate/components/components/nn/nn_runner.py b/python/fate/components/components/nn/nn_runner.py index 2167b0c29c..9da71f4ed7 100644 --- a/python/fate/components/components/nn/nn_runner.py +++ b/python/fate/components/components/nn/nn_runner.py @@ -1,5 +1,5 @@ import numpy as np -import torch +import torch import pandas as pd from typing import Union, Optional, Literal from fate.components.core import Role @@ -17,17 +17,18 @@ logger = logging.getLogger(__name__) -def _convert_to_numpy_array(data: Union[pd.Series, pd.DataFrame, np.ndarray, torch.Tensor]) -> np.ndarray: +def _convert_to_numpy_array( + data: Union[pd.Series, pd.DataFrame, np.ndarray, torch.Tensor]) -> np.ndarray: if isinstance(data, pd.Series) or isinstance(data, pd.DataFrame): return data.to_numpy() elif isinstance(data, torch.Tensor): return data.cpu().numpy() else: return np.array(data) - + def task_type_infer(predict_result, true_label): - + pred_shape = predict_result.shape if true_label.max() == 1.0 and true_label.min() == 0.0: @@ -50,7 +51,7 @@ def task_type_infer(predict_result, true_label): class NNRunner(object): def 
diff --git a/python/fate/components/components/nn/nn_runner.py b/python/fate/components/components/nn/nn_runner.py
index 2167b0c29c..9da71f4ed7 100644
--- a/python/fate/components/components/nn/nn_runner.py
+++ b/python/fate/components/components/nn/nn_runner.py
@@ -1,5 +1,5 @@
 import numpy as np
-import torch 
+import torch
 import pandas as pd
 from typing import Union, Optional, Literal
 from fate.components.core import Role
@@ -17,17 +17,18 @@
 logger = logging.getLogger(__name__)


-def _convert_to_numpy_array(data: Union[pd.Series, pd.DataFrame, np.ndarray, torch.Tensor]) -> np.ndarray:
+def _convert_to_numpy_array(
+        data: Union[pd.Series, pd.DataFrame, np.ndarray, torch.Tensor]) -> np.ndarray:
     if isinstance(data, pd.Series) or isinstance(data, pd.DataFrame):
         return data.to_numpy()
     elif isinstance(data, torch.Tensor):
         return data.cpu().numpy()
     else:
         return np.array(data)
-    
+

 def task_type_infer(predict_result, true_label):
-    
+
     pred_shape = predict_result.shape

     if true_label.max() == 1.0 and true_label.min() == 0.0:
@@ -50,7 +51,7 @@ class NNRunner(object):

     def __init__(self) -> None:
-        
+
         self._role = None
         self._party_id = None
         self._ctx: Context = None
@@ -68,14 +69,14 @@ def set_role(self, role: Role):

     def is_client(self) -> bool:
         return self._role.is_guest or self._role.is_host
-    
+
     def is_server(self) -> bool:
         return self._role.is_arbiter
-    
+
     def set_party_id(self, party_id: int):
         assert isinstance(party_id, int)
         self._party_id = party_id
-    
+
     def get_fateboard_tracker(self):
         pass
@@ -92,7 +93,7 @@ def get_nn_output_dataframe(
             task_type: Literal['binary', 'multi', 'regression', 'others'] = None,
             threshold: float = 0.5,
             classes: list = None
-    )-> DataFrame:
+    ) -> DataFrame:
         """
         Constructs a FATE DataFrame from predictions and labels.
         This DataFrame is able to flow through FATE components.
@@ -107,15 +108,17 @@
             dataframe_format (Literal['default', 'fate_std'], optional): Output format of the resulting DataFrame.
                 If 'default', simply combines labels and predictions into a DataFrame.
                 If 'fate_std', organizes output according to the FATE framework's format. Defaults to 'default'.
             task_type (Literal['binary', 'multi', 'regression', 'others'], optional): This parameter is only needed when dataframe_format is 'fate_std'. Defaults to None.
-                The type of machine learning task, which can be 'binary', 'multi', 'regression', or 'others'. 
+                The type of machine learning task, which can be 'binary', 'multi', 'regression', or 'others'.
             threshold (float, optional): This parameter is only needed when dataframe_format is 'fate_std' and task_type is 'binary'. Defaults to 0.5.
             classes (list, optional): This parameter is only needed when dataframe_format is 'fate_std'. List of classes.

         Returns:
             DataFrame: A DataFrame that contains the neural network's predictions and the true labels,
                 possibly along with match IDs and sample IDs, formatted according to the specified format.
         """
         # check parameters
-        assert task_type in ['binary', 'multi', 'regression', 'others'], f"task_type {task_type} is not supported"
-        assert dataframe_format in ['default', 'fate_std'], f"dataframe_format {dataframe_format} is not supported"
+        assert task_type in ['binary', 'multi', 'regression',
                              'others'], f"task_type {task_type} is not supported"
+        assert dataframe_format in [
+            'default', 'fate_std'], f"dataframe_format {dataframe_format} is not supported"

         if match_id_name is None:
             match_id_name = 'id'
@@ -129,35 +132,53 @@ def get_nn_output_dataframe(

         predictions = _convert_to_numpy_array(predictions)
         labels = _convert_to_numpy_array(labels)
-        assert len(predictions) == len(labels), f"predictions length {len(predictions)} != labels length {len(labels)}"
-        
+        assert len(predictions) == len(
+            labels), f"predictions length {len(predictions)} != labels length {len(labels)}"
+
         # check match ids
         if match_ids is not None:
             match_ids = _convert_to_numpy_array(match_ids).flatten()
         else:
-            logger.info("match_ids is not provided, will auto generate match_ids")
-            match_ids = np.array([i for i in range(len(predictions))]).flatten()
-        
+            logger.info(
+                "match_ids is not provided, will auto generate match_ids")
+            match_ids = np.array(
+                [i for i in range(len(predictions))]).flatten()
+
         # check sample ids
         if sample_ids is not None:
             sample_ids = _convert_to_numpy_array(sample_ids).flatten()
         else:
-            logger.info("sample_ids is not provided, will auto generate sample_ids")
-            sample_ids = np.array([i for i in range(len(predictions))]).flatten()
+            logger.info(
+                "sample_ids is not provided, will auto generate sample_ids")
+            sample_ids = np.array(
+                [i for i in range(len(predictions))]).flatten()

-        assert len(match_ids) == len(predictions), f"match_ids length {len(match_ids)} != predictions length {len(predictions)}"
-        assert len(sample_ids) == len(predictions), f"sample_ids length {len(sample_ids)} != predictions length {len(predictions)}"
+        assert len(match_ids) == len(
+            predictions), f"match_ids length {len(match_ids)} != predictions length {len(predictions)}"
+        assert len(sample_ids) == len(
+            predictions), f"sample_ids length {len(sample_ids)} != predictions length {len(predictions)}"

         # match id name and sample id name must be str
-        assert isinstance(match_id_name, str), f"match_id_name must be str, but got {type(match_id_name)}"
-        assert isinstance(sample_id_name, str), f"sample_id_name must be str, but got {type(sample_id_name)}"
-        
-        if dataframe_format == 'default' or (dataframe_format == 'fate_std' and task_type == 'others'):
-            df = pd.DataFrame({'label': labels.to_list(), 'predict': predictions.to_list(), match_id_name: match_ids.to_list(), sample_id_name: sample_ids.to_list()})
+        assert isinstance(
+            match_id_name, str), f"match_id_name must be str, but got {type(match_id_name)}"
+        assert isinstance(
+            sample_id_name, str), f"sample_id_name must be str, but got {type(sample_id_name)}"
+
+        if dataframe_format == 'default' or (
+                dataframe_format == 'fate_std' and task_type == 'others'):
+            df = pd.DataFrame({'label': labels.tolist(),
+                               'predict': predictions.tolist(),
+                               match_id_name: match_ids.tolist(),
+                               sample_id_name: sample_ids.tolist()})
             df = to_fate_df(ctx, sample_id_name, match_id_name, df)
             return df
         elif dataframe_format == 'fate_std' and task_type in ['binary', 'multi', 'regression']:
-            df = std_output_df(task_type, predictions, labels, threshold, classes)
+            df = std_output_df(
+                task_type,
+                predictions,
+                labels,
+                threshold,
+                classes)
             match_id_df = pd.DataFrame()
             match_id_df[match_id_name] = match_ids
             sample_id_df = pd.DataFrame()
@@ -166,8 +187,13 @@ def get_nn_output_dataframe(
             df = to_fate_df(ctx, sample_id_name, match_id_name, df)
             return df
-
-    def train(self, train_data: Optional[Union[str, DataFrame]] = None, validate_data: Optional[Union[str, DataFrame]] = None, output_dir: str = None, saved_model_path: str = None) -> None:
+    def train(self,
+              train_data: Optional[Union[str,
+                                         DataFrame]] = None,
+              validate_data: Optional[Union[str,
+                                            DataFrame]] = None,
+              output_dir: str = None,
+              saved_model_path: str = None) -> None:
         """
         Train interface.
@@ -179,7 +205,11 @@ def train(self, train_data: Optional[Union[str, DataFrame]] = None, validate_dat
         """
         pass

-    def predict(self, test_data: Optional[Union[str, DataFrame]] = None, output_dir: str = None, saved_model_path: str = None) -> DataFrame:
+    def predict(self,
+                test_data: Optional[Union[str,
+                                          DataFrame]] = None,
+                output_dir: str = None,
+                saved_model_path: str = None) -> DataFrame:
         """
         Predict interface.
@@ -188,4 +218,3 @@ def predict(self, test_data: Optional[Union[str, DataFrame]] = None, output_dir:
             output_dir (str, optional): The path to the directory where the trained model should be saved. If this class is running in the fate pipeline, this path will be provided by the FATE framework.
             saved_model_path (str, optional): The path to the saved model that should be loaded before training starts. If this class is running in the fate pipeline, this path will be provided by the FATE framework.
         """
-
diff --git a/python/fate/components/components/nn/runner/default_runner.py b/python/fate/components/components/nn/runner/default_runner.py
index cf0a34caa8..6df95203b3 100644
--- a/python/fate/components/components/nn/runner/default_runner.py
+++ b/python/fate/components/components/nn/runner/default_runner.py
@@ -27,14 +27,17 @@
 def load_model_dict_from_path(path):
     # Ensure that the path is a string
-    assert isinstance(path, str), "Path must be a string, but got {}".format(type(path))
+    assert isinstance(
+        path, str), "Path must be a string, but got {}".format(
+        type(path))

     # Append the filename to the path
     model_path = os.path.join(path, 'pytorch_model.bin')

     # Check if the file exists
     if not os.path.isfile(model_path):
-        raise FileNotFoundError(f"No 'pytorch_model.bin' file found at {model_path}, no saved model found")
+        raise FileNotFoundError(
+            f"No 'pytorch_model.bin' file found at {model_path}, no saved model found")

     # Load the state dict from the specified path
     model_dict = t.load(model_path)
@@ -44,8 +47,9 @@

 def dir_warning(train_args):
     if 'output_dir' in train_args or 'logging_dir' in train_args or 'resume_from_checkpoint' in train_args:
-        logger.warning("The output_dir, logging_dir, and resume_from_checkpoint arguments are not supported in the "
-                       "DefaultRunner when running the Pipeline. These arguments will be replaced by FATE provided paths.")
+        logger.warning(
+            "The output_dir, logging_dir, and resume_from_checkpoint arguments are not supported in the "
+            "DefaultRunner when running the Pipeline. 
These arguments will be replaced by FATE provided paths.") class SetupReturn: @@ -54,38 +58,55 @@ class SetupReturn: """ def __init__(self, - trainer: Union[Type[FedTrainerClient], Type[FedTrainerServer]] = None, - model: Type[nn.Module] = None, - optimizer: Type[optim.Optimizer] = None, - loss: Callable = None, - scheduler: Type[_LRScheduler] = None, - train_args: TrainingArguments = None, - fed_args: FedArguments = None, - data_collator: Callable = None) -> None: - - if trainer is not None and not (issubclass(type(trainer), FedTrainerClient) or issubclass(type(trainer), FedTrainerServer)): - raise TypeError(f"SetupReturn Error: trainer must be a subclass of either FedTrainerClient or FedTrainerServer but got {type(trainer)}") - + trainer: Union[Type[FedTrainerClient], + Type[FedTrainerServer]] = None, + model: Type[nn.Module] = None, + optimizer: Type[optim.Optimizer] = None, + loss: Callable = None, + scheduler: Type[_LRScheduler] = None, + train_args: TrainingArguments = None, + fed_args: FedArguments = None, + data_collator: Callable = None) -> None: + + if trainer is not None and not ( + issubclass( + type(trainer), + FedTrainerClient) or issubclass( + type(trainer), + FedTrainerServer)): + raise TypeError( + f"SetupReturn Error: trainer must be a subclass of either FedTrainerClient or FedTrainerServer but got {type(trainer)}") + if model is not None and not issubclass(type(model), nn.Module): - raise TypeError(f"SetupReturn Error: model must be a subclass of torch.nn.Module but got {type(model)}") - - if optimizer is not None and not issubclass(type(optimizer), optim.Optimizer): - raise TypeError(f"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}") - + raise TypeError( + f"SetupReturn Error: model must be a subclass of torch.nn.Module but got {type(model)}") + + if optimizer is not None and not issubclass( + type(optimizer), optim.Optimizer): + raise TypeError( + f"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}") + if loss is not None and not callable(loss): - raise TypeError(f"SetupReturn Error: loss must be callable but got {type(loss)}") - - if scheduler is not None and not issubclass(type(scheduler), _LRScheduler): - raise TypeError(f"SetupReturn Error: scheduler must be a subclass of torch.optim.lr_scheduler._LRScheduler but got {type(scheduler)}") - - if train_args is not None and not isinstance(train_args, TrainingArguments): - raise TypeError(f"SetupReturn Error: train_args must be an instance of TrainingArguments but got {type(train_args)}") - + raise TypeError( + f"SetupReturn Error: loss must be callable but got {type(loss)}") + + if scheduler is not None and not issubclass( + type(scheduler), _LRScheduler): + raise TypeError( + f"SetupReturn Error: scheduler must be a subclass of torch.optim.lr_scheduler._LRScheduler but got {type(scheduler)}") + + if train_args is not None and not isinstance( + train_args, TrainingArguments): + raise TypeError( + f"SetupReturn Error: train_args must be an instance of TrainingArguments but got {type(train_args)}") + if fed_args is not None and not isinstance(fed_args, FedArguments): - raise TypeError(f"SetupReturn Error: fed_args must be an instance of FedArguments but got {type(fed_args)}") - + raise TypeError( + f"SetupReturn Error: fed_args must be an instance of FedArguments but got {type(fed_args)}") + if data_collator is not None and not callable(data_collator): - raise TypeError(f"SetupReturn Error: data_collator must be callable 
but got {type(data_collator)}") + raise TypeError( + f"SetupReturn Error: data_collator must be callable but got {type(data_collator)}") self.trainer = trainer self.model = model @@ -110,9 +131,9 @@ def __repr__(self): class DefaultRunner(NNRunner): - def __init__(self, + def __init__(self, algo: str = 'fedavg', - model_conf: Optional[Dict] = None, + model_conf: Optional[Dict] = None, dataset_conf: Optional[Dict] = None, optimizer_conf: Optional[Dict] = None, training_args_conf: Optional[Dict] = None, @@ -120,11 +141,13 @@ def __init__(self, loss_conf: Optional[Dict] = None, data_collator_conf: Optional[Dict] = None, tokenizer_conf: Optional[Dict] = None, - task_type: Literal['binary', 'multi', 'regression', 'others'] = 'binary', + task_type: Literal['binary', + 'multi', + 'regression', + 'others'] = 'binary', threshold: float = 0.5, - local_mode: bool = False - ) -> None: - + local_mode: bool = False) -> None: + super().__init__() self.algo = algo self.model_conf = model_conf @@ -143,7 +166,8 @@ def __init__(self, if self.algo not in SUPPORTED_ALGO: raise ValueError('algo should be one of [fedavg]') if self.task_type not in ['binary', 'multi', 'regression', 'others']: - raise ValueError('task_type should be one of [binary, multi, regression, others]') + raise ValueError( + 'task_type should be one of [binary, multi, regression, others]') assert self.threshold >= 0 and self.threshold <= 1, 'threshold should be in [0, 1]' assert isinstance(self.local_mode, bool), 'local should be bool' @@ -158,13 +182,17 @@ def _loader_load_from_conf(self, conf, return_class=False): return Loader.from_dict(conf).call_item() def _prepare_data(self, data, data_name) -> SetupReturn: - + if data is None: return None if isinstance(data, DataFrame) and self.dataset_conf is None: - logger.info('Input data {} is FATE DataFrame and dataset conf is None, will automatically handle the input data'.format(data_name)) + logger.info( + 'Input data {} is FATE DataFrame and dataset conf is None, will automatically handle the input data'.format(data_name)) if self.task_type == consts.MULTI: - dataset = TableDataset(flatten_label=True, label_dtype='long', to_tensor=True) + dataset = TableDataset( + flatten_label=True, + label_dtype='long', + to_tensor=True) else: dataset = TableDataset(to_tensor=True) dataset.load(data) @@ -173,15 +201,24 @@ def _prepare_data(self, data, data_name) -> SetupReturn: if hasattr(dataset, 'load'): dataset.load(data) else: - raise ValueError(f"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \ + raise ValueError( + f"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \ Please implement this method in your dataset class. 
You can refer to the base class 'Dataset' in 'fate.ml.nn.dataset.base' \ for the necessary interfaces to implement.") - if dataset is not None and not issubclass(type(dataset), data_utils.Dataset): - raise TypeError(f"SetupReturn Error: {data_name}_set must be a subclass of torch.utils.data.Dataset but got {type(dataset)}") - + if dataset is not None and not issubclass( + type(dataset), data_utils.Dataset): + raise TypeError( + f"SetupReturn Error: {data_name}_set must be a subclass of torch.utils.data.Dataset but got {type(dataset)}") + return dataset - - def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage='train'): + + def client_setup( + self, + train_set=None, + validate_set=None, + output_dir=None, + saved_model=None, + stage='train'): if stage == 'predict': self.local_mode = True @@ -190,11 +227,12 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved client_class: FedAVGCLient = FedAVG.client else: raise ValueError(f"algo {self.algo} not supported") - + ctx = self.get_context() model = self._loader_load_from_conf(self.model_conf) if model is None: - raise ValueError(f"model is None, cannot load model from conf {self.model_conf}") + raise ValueError( + f"model is None, cannot load model from conf {self.model_conf}") if output_dir is None: output_dir = './' @@ -205,7 +243,8 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved logger.info(f"loading model dict from {saved_model} to model done") if get_last_checkpoint(saved_model) is not None: resume_path = saved_model - logger.info(f"checkpoint detected, resume_path set to {resume_path}") + logger.info( + f"checkpoint detected, resume_path set to {resume_path}") else: resume_path = None @@ -223,21 +262,37 @@ def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved # args dir_warning(self.training_args_conf) training_args = TrainingArguments(**self.training_args_conf) - training_args.output_dir = output_dir # reset to default, saving to arbitrary path is not allowed in DefaultRunner + # reset to default, saving to arbitrary path is not allowed in + # DefaultRunner + training_args.output_dir = output_dir training_args.resume_from_checkpoint = resume_path # resume path fed_args = FedAVGArguments(**self.fed_args_conf) # prepare trainer - trainer = client_class(ctx=ctx, model=model, loss_fn=loss, - optimizer=optimizer, training_args=training_args, - fed_args=fed_args, data_collator=data_collator, - tokenizer=tokenizer, train_set=train_set, val_set=validate_set, local_mode=self.local_mode) - - return SetupReturn(trainer=trainer, model=model, optimizer=optimizer, loss=loss, - train_args=training_args, fed_args=fed_args, data_collator=data_collator) - + trainer = client_class( + ctx=ctx, + model=model, + loss_fn=loss, + optimizer=optimizer, + training_args=training_args, + fed_args=fed_args, + data_collator=data_collator, + tokenizer=tokenizer, + train_set=train_set, + val_set=validate_set, + local_mode=self.local_mode) + + return SetupReturn( + trainer=trainer, + model=model, + optimizer=optimizer, + loss=loss, + train_args=training_args, + fed_args=fed_args, + data_collator=data_collator) + def server_setup(self, stage='train'): - + if stage == 'predict': self.local_mode = True if self.algo == 'fedavg': @@ -247,13 +302,23 @@ def server_setup(self, stage='train'): ctx = self.get_context() trainer = server_class(ctx=ctx, local_mode=self.local_mode) return SetupReturn(trainer=trainer) - - def train(self, train_data: 
Optional[Union[str, DataFrame]] = None, validate_data: Optional[Union[str, DataFrame]] = None, output_dir: str = None, saved_model_path: str = None): - + + def train(self, + train_data: Optional[Union[str, + DataFrame]] = None, + validate_data: Optional[Union[str, + DataFrame]] = None, + output_dir: str = None, + saved_model_path: str = None): + if self.is_client(): train_set = self._prepare_data(train_data, 'train_data') validate_set = self._prepare_data(validate_data, 'val_data') - setup = self.client_setup(train_set=train_set, validate_set=validate_set, output_dir=output_dir, saved_model=saved_model_path) + setup = self.client_setup( + train_set=train_set, + validate_set=validate_set, + output_dir=output_dir, + saved_model=saved_model_path) trainer = setup['trainer'] self.trainer = trainer trainer.train() @@ -269,35 +334,50 @@ def _run_dataset_func(self, dataset, func_name): if hasattr(dataset, func_name): output = getattr(dataset, func_name)() if output is None: - logger.info(f'dataset {type(dataset)}: {func_name} returns None, this will influence the output of predict') + logger.info( + f'dataset {type(dataset)}: {func_name} returns None, this will influence the output of predict') return output else: - logger.info(f'dataset {type(dataset)} not implemented {func_name}, classes set to None, this will influence the output of predict') + logger.info( + f'dataset {type(dataset)} not implemented {func_name}, classes set to None, this will influence the output of predict') return None - def predict(self, test_data: Union[str, DataFrame], saved_model_path: str = None) -> Union[DataFrame, None]: - + def predict(self, + test_data: Union[str, + DataFrame], + saved_model_path: str = None) -> Union[DataFrame, + None]: + if self.is_client(): test_set = self._prepare_data(test_data, 'test_data') if self.trainer is not None: trainer = self.trainer logger.info('trainer found, skip setting up') else: - setup = self.client_setup(saved_model=saved_model_path, stage='predict') + setup = self.client_setup( + saved_model=saved_model_path, stage='predict') trainer = setup['trainer'] classes = self._run_dataset_func(test_set, 'get_classes') match_ids = self._run_dataset_func(test_set, 'get_match_ids') sample_ids = self._run_dataset_func(test_set, 'get_sample_ids') - match_id_name = self._run_dataset_func(test_set, 'get_match_id_name') - sample_id_name = self._run_dataset_func(test_set, 'get_sample_id_name') + match_id_name = self._run_dataset_func( + test_set, 'get_match_id_name') + sample_id_name = self._run_dataset_func( + test_set, 'get_sample_id_name') pred_rs = trainer.predict(test_set) - rs_df = self.get_nn_output_dataframe(self.get_context(), pred_rs.predictions, pred_rs.label_ids, match_ids, sample_ids, match_id_name=match_id_name, sample_id_name=sample_id_name, - dataframe_format='fate_std', task_type=self.task_type, classes=classes) + rs_df = self.get_nn_output_dataframe( + self.get_context(), + pred_rs.predictions, + pred_rs.label_ids, + match_ids, + sample_ids, + match_id_name=match_id_name, + sample_id_name=sample_id_name, + dataframe_format='fate_std', + task_type=self.task_type, + classes=classes) return rs_df else: # server not predict - return - - - + return diff --git a/python/fate/components/components/nn/runner/my_runner.py b/python/fate/components/components/nn/runner/my_runner.py index b2b4fbc1d2..63fdc047ae 100644 --- a/python/fate/components/components/nn/runner/my_runner.py +++ b/python/fate/components/components/nn/runner/my_runner.py @@ -7,7 +7,12 @@ class MyRunner(NNRunner): - def 
__init__(self, in_feat=30, epoch=10, learning_rate=0.01, batch_size=32) -> None: + def __init__( + self, + in_feat=30, + epoch=10, + learning_rate=0.01, + batch_size=32) -> None: super().__init__() self.in_feat = in_feat self.epoch = epoch @@ -15,7 +20,7 @@ def __init__(self, in_feat=30, epoch=10, learning_rate=0.01, batch_size=32) -> N self.batch_size = batch_size def setup(self, df=None): - + ctx = self.get_context() if self.is_client(): @@ -37,16 +42,19 @@ def setup(self, df=None): optimizer = t.optim.Adam(model.parameters(), lr=self.learning_rate) - train_arg = TrainingArguments(num_train_epochs=self.epoch, per_device_train_batch_size=self.batch_size, disable_tqdm=False) + train_arg = TrainingArguments( + num_train_epochs=self.epoch, + per_device_train_batch_size=self.batch_size, + disable_tqdm=False) fed_arg = FedAVGArguments() - return FedAVGCLient(ctx=ctx, model=model, optimizer=optimizer, - training_args=train_arg, fed_args=fed_arg, train_set=dataset, loss_fn=loss_fn), dataset + return FedAVGCLient(ctx=ctx, model=model, optimizer=optimizer, training_args=train_arg, + fed_args=fed_arg, train_set=dataset, loss_fn=loss_fn), dataset elif self.is_server(): return FedAVGServer(ctx=ctx) - + def train(self, input_data: NNInput = None): if self.is_client(): df = input_data.get('train_data') @@ -55,7 +63,7 @@ def train(self, input_data: NNInput = None): trainer = self.setup() trainer.train() - + def predict(self, input_data: NNInput = None): if self.is_client(): diff --git a/python/fate/components/components/nn/test/test_default_runner.py b/python/fate/components/components/nn/test/test_default_runner.py index 43a5978320..42c3a14bc3 100644 --- a/python/fate/components/components/nn/test/test_default_runner.py +++ b/python/fate/components/components/nn/test/test_default_runner.py @@ -18,39 +18,51 @@ logger.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setLevel(logging.DEBUG) -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') ch.setFormatter(formatter) logger.addHandler(ch) computing = CSession() -ctx = Context( - "guest", - computing=computing, - federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]), -) +ctx = Context("guest", computing=computing, federation=StandaloneFederation( + computing, "fed", ("guest", 10000), [("host", 9999)]), ) -df = pd.read_csv('./../../../../../../examples/data/vehicle_scale_homo_guest.csv') +df = pd.read_csv( + './../../../../../../examples/data/vehicle_scale_homo_guest.csv') df['sample_id'] = [i for i in range(len(df))] -reader = PandasReader(sample_id_name='sample_id', match_id_name="id", label_name="y", dtype="object") +reader = PandasReader( + sample_id_name='sample_id', + match_id_name="id", + label_name="y", + dtype="object") data = reader.to_frame(ctx, df) -runner_conf=get_config_of_default_runner( - algo='fedavg', - model=Sequential( - nn.Linear(18, 10), - nn.ReLU(), - nn.Linear(10 ,4), - nn.Softmax() - ), - loss=nn.CrossEntropyLoss(), - dataset=DatasetLoader('table', 'TableDataset', flatten_label=True, label_dtype='long'), - optimizer=optim.Adam(lr=0.01), - training_args=TrainingArguments(num_train_epochs=50, per_device_train_batch_size=128), - fed_args=FedAVGArguments(), - task_type='binary' - ) +runner_conf = get_config_of_default_runner( + algo='fedavg', + model=Sequential( + nn.Linear( + 18, + 10), + nn.ReLU(), + nn.Linear( + 10, + 4), + nn.Softmax()), + loss=nn.CrossEntropyLoss(), + 
dataset=DatasetLoader( + 'table', + 'TableDataset', + flatten_label=True, + label_dtype='long'), + optimizer=optim.Adam( + lr=0.01), + training_args=TrainingArguments( + num_train_epochs=50, + per_device_train_batch_size=128), + fed_args=FedAVGArguments(), + task_type='binary') runner = DefaultRunner(**runner_conf) runner.set_context(ctx) diff --git a/python/fate/ml/glm/homo/lr/client.py b/python/fate/ml/glm/homo/lr/client.py index daa8d45a67..12b061aa0f 100644 --- a/python/fate/ml/glm/homo/lr/client.py +++ b/python/fate/ml/glm/homo/lr/client.py @@ -19,13 +19,12 @@ logger = logging.getLogger(__name__) - def homo_lr_loss(pred, labels, dim=1): """ The function assumes that pred has shape (n, num_classes) where each class has its own linear model. labels have shape (n,) and the values are integers denoting the class. """ - + # initialize the loss loss = 0.0 loss_fn = t.nn.BCELoss() @@ -48,8 +47,10 @@ class HomoLRModel(t.nn.Module): def __init__(self, feature_num, label_num=2, l1=0, bias=True) -> None: super().__init__() - assert feature_num >= 2 and isinstance(feature_num, int), "feature_num must be int greater than 2" - assert label_num >= 1 and isinstance(label_num, int), "label_num must be int greater than 1" + assert feature_num >= 2 and isinstance( + feature_num, int), "feature_num must be int greater than 2" + assert label_num >= 1 and isinstance( + label_num, int), "label_num must be int greater than 1" self.models = t.nn.ModuleList() if 2 >= label_num > 0: @@ -81,7 +82,7 @@ def forward(self, x, labels=None): linear_out = self.softmax(linear_out) ret_dict['pred'] = linear_out - + if labels is not None: loss = homo_lr_loss(linear_out, labels, dim=len(self.models)) if self.l1 != 0: @@ -92,19 +93,22 @@ def forward(self, x, labels=None): ret_dict['loss'] = loss return ret_dict - + def to_dict(self): model_dict = { "feature_num": self.models[0].in_features, "label_num": len(self.models), - "state_dict": {k: v.tolist() for k, v in self.state_dict().items()} # convert tensor to list + # convert tensor to list + "state_dict": {k: v.tolist() for k, v in self.state_dict().items()} } return model_dict @classmethod def from_dict(cls, model_dict): model = cls(model_dict["feature_num"], model_dict["label_num"]) - model_state_dict = {k: t.tensor(v) for k, v in model_dict["state_dict"].items()} # convert list back to tensor + model_state_dict = { + k: t.tensor(v) for k, + v in model_dict["state_dict"].items()} # convert list back to tensor model.load_state_dict(model_state_dict) return model @@ -115,15 +119,17 @@ def init_model(model, method='random', fill_val=1.0): elif method == 'ones': init_fn = nn.init.ones_ elif method == 'consts': - init_fn = lambda x: nn.init.constant_(x, fill_val) + def init_fn(x): return nn.init.constant_(x, fill_val) elif method == 'random': init_fn = nn.init.normal_ else: - raise ValueError("Invalid method. Options are: 'zeros', 'ones', 'consts', 'random'") - + raise ValueError( + "Invalid method. 
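For intuition, the per-class BCE accumulation described in homo_lr_loss's docstring above can be reproduced in a few lines; this is an illustrative sketch of the idea, not the exact FATE implementation:

    import torch as t

    def ovr_bce(pred, labels, dim):
        # score class k against the binary indicator (labels == k)
        loss_fn = t.nn.BCELoss()
        loss = sum(loss_fn(pred[:, k], (labels == k).float())
                   for k in range(dim))
        return loss / dim

    pred = t.tensor([[0.9, 0.1], [0.2, 0.8]])   # already sigmoid/softmax scores
    labels = t.tensor([0, 1])
    print(ovr_bce(pred, labels, dim=2))          # low loss: confident and correct
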
Options are: 'zeros', 'ones', 'consts', 'random'") + for name, param in model.named_parameters(): if 'bias' in name: - nn.init.zeros_(param) # usually it's good practice to initialize biases to zero + # usually it's good practice to initialize biases to zero + nn.init.zeros_(param) else: init_fn(param) @@ -153,7 +159,10 @@ def update_params(new_params, default, name='optimizer'): import copy params = copy.deepcopy(default) if not isinstance(new_params, dict): - raise ValueError("{} param dict must be a dict but got {}".format(name, new_params)) + raise ValueError( + "{} param dict must be a dict but got {}".format( + name, new_params)) + def _update(default, new): for key in new.keys(): if key in default: @@ -164,22 +173,47 @@ def _update(default, new): return params -DEFAULT_OPT_PARAM = {'method': 'sgd', 'penalty': 'l2', 'alpha': 0.0, 'optimizer_params': {'lr': 0.01, 'weight_decay': 0}} -DEFAULT_INIT_PARAM = {"method": "random", "fill_val": 1.0, "fit_intercept": True} -DEFAULT_LR_SCHEDULER_PARAM = {'method': 'constant', 'scheduler_params': {'factor': 1.0}} +DEFAULT_OPT_PARAM = { + 'method': 'sgd', + 'penalty': 'l2', + 'alpha': 0.0, + 'optimizer_params': { + 'lr': 0.01, + 'weight_decay': 0}} +DEFAULT_INIT_PARAM = { + "method": "random", + "fill_val": 1.0, + "fit_intercept": True} +DEFAULT_LR_SCHEDULER_PARAM = { + 'method': 'constant', + 'scheduler_params': { + 'factor': 1.0}} class HomoLRClient(HomoModule): - def __init__(self, epochs: int=5, batch_size: int=32, - optimizer_param={'method': 'sgd', 'optimizer_params': {'lr': 0.01, 'weight_decay': 0}}, - learning_rate_scheduler={'method': 'constant', 'scheduler_params': {'factor': 1.0}}, - init_param={"method": "random", "fill_val": 1.0, "fit_intercept": True}, - threshold: float=0.5, - ovr=False, - label_num=None, - ) -> None: - + def __init__( + self, + epochs: int = 5, + batch_size: int = 32, + optimizer_param={ + 'method': 'sgd', + 'optimizer_params': { + 'lr': 0.01, + 'weight_decay': 0}}, + learning_rate_scheduler={ + 'method': 'constant', + 'scheduler_params': { + 'factor': 1.0}}, + init_param={ + "method": "random", + "fill_val": 1.0, + "fit_intercept": True}, + threshold: float = 0.5, + ovr=False, + label_num=None, + ) -> None: + super().__init__() self.df_schema = None self.train_set = None @@ -189,9 +223,14 @@ def __init__(self, epochs: int=5, batch_size: int=32, # set vars self.max_iter = epochs self.batch_size = batch_size - self.optimizer_param = update_params(optimizer_param, DEFAULT_OPT_PARAM, name='optimizer') - self.learning_rate_param = update_params(learning_rate_scheduler, DEFAULT_LR_SCHEDULER_PARAM, name='learning_rate_scheduler') - self.init_param = update_params(init_param, DEFAULT_INIT_PARAM, name='init_param') + self.optimizer_param = update_params( + optimizer_param, DEFAULT_OPT_PARAM, name='optimizer') + self.learning_rate_param = update_params( + learning_rate_scheduler, + DEFAULT_LR_SCHEDULER_PARAM, + name='learning_rate_scheduler') + self.init_param = update_params( + init_param, DEFAULT_INIT_PARAM, name='init_param') self.threshold = threshold self.run_ovr = False self.train_feature_num = None @@ -201,8 +240,10 @@ def __init__(self, epochs: int=5, batch_size: int=32, if self.ovr: if self.label_num is None or self.label_num < 2: - raise ValueError("label_num must be greater than 2 when ovr is True, but got {}".format(self.label_num)) - + raise ValueError( + "label_num must be greater than 2 when ovr is True, but got {}".format( + self.label_num)) + # models & optimizer & schduler self.model = None self.optimizer 
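The update_params helper above deserves a worked example: user keys recurse into nested default dicts, while keys absent from the defaults are silently dropped, which is why a misspelled option goes unnoticed rather than raising:

    user_conf = {'method': 'adam', 'optimizer_params': {'lr': 0.1}, 'lr_typo': 1}
    merged = update_params(user_conf, DEFAULT_OPT_PARAM, name='optimizer')
    # merged == {'method': 'adam', 'penalty': 'l2', 'alpha': 0.0,
    #            'optimizer_params': {'lr': 0.1, 'weight_decay': 0}}
    # 'lr_typo' is ignored because it is not a key of DEFAULT_OPT_PARAM
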
= None
@@ -221,47 +262,68 @@ def __init__(self, epochs: int=5, batch_size: int=32,
         self.local_mode = False
 
         # checking params
-        assert self.max_iter > 0 and isinstance(self.max_iter, int), "max_iter must be int greater than 0"
+        assert self.max_iter > 0 and isinstance(
+            self.max_iter, int), "max_iter must be int greater than 0"
         if self.batch_size != -1:
-            assert self.batch_size > 0 and isinstance(self.batch_size, int), "batch_size must be int greater than 0 or -1"
+            assert self.batch_size > 0 and isinstance(
+                self.batch_size, int), "batch_size must be int greater than 0 or -1"
         assert self.threshold > 0 and self.threshold < 1, "threshold must be float between 0 and 1"
-
+
     def _make_dataset(self, data) -> TableDataset:
         ds = TableDataset(return_dict=True, to_tensor=True)
         ds.load(data)
         return ds
-
-    def _make_output_df(self, predict_rs, data: TableDataset, threshold: float):
+
+    def _make_output_df(
+            self,
+            predict_rs,
+            data: TableDataset,
+            threshold: float):
         classes = [i for i in range(len(self.model.models))]
         if len(classes) == 1:  # binary:
             classes = [0, 1]
         task_type = BINARY if len(classes) == 2 else MULTI
 
-        out_df = std_output_df(task_type, predict_rs.predictions, predict_rs.label_ids, threshold=threshold, classes=classes)
+        out_df = std_output_df(
+            task_type,
+            predict_rs.predictions,
+            predict_rs.label_ids,
+            threshold=threshold,
+            classes=classes)
         out_df = add_ids(out_df, data.get_match_ids(), data.get_sample_ids())
         return out_df
-
+
     def _check_labels(self, label_set, has_validate=False):
-
+
         dataset_descrb = 'train dataset' if not has_validate else 'train and validate dataset'
         if not self.ovr and len(label_set) > 2:
-            raise ValueError("please set ovr=True to enable multi-label classification, multiple labels found in {}: {}".format(dataset_descrb, label_set))
+            raise ValueError(
+                "please set ovr=True to enable multi-label classification, multiple labels found in {}: {}".format(
+                    dataset_descrb, label_set))
         if not self.ovr and len(label_set) == 2:
             # 0, 1 is required
             if 0 not in label_set or 1 not in label_set:
                 # ask for label 0, 1 when running binary classification
-                raise ValueError("when doing binary classification, lables must be 0, 1, but found in {}'s label set is {}".format(label_set, dataset_descrb))
+                raise ValueError(
+                    "when doing binary classification, labels must be 0 and 1, but {}'s label set is {}".format(
+                        dataset_descrb, label_set))
         if self.ovr:
             if max(label_set) > self.label_num - 1:
-                # make sure labels start from 0 and not the label indices not exceed the label num parameter
-                raise ValueError("when doing multi-label classification, labels must start from 0 and not exceed the label num parameter, \
-                                 but {}'s label set is {}, while label num is {}".format(label_set, dataset_descrb, self.label_num))
+                # make sure labels start from 0 and the label indices do not
+                # exceed the label num parameter
+                raise ValueError(
+                    "when doing multi-label classification, labels must start from 0 and not exceed the label num parameter, \
+                    but {}'s label set is {}, while label num is {}".format(
+                        dataset_descrb, label_set, self.label_num))
 
-    def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = None) -> None:
+    def fit(self, ctx: Context, train_data: DataFrame,
+            validate_data: DataFrame = None) -> None:
 
         # check data, must be fate Dataframe
-        assert isinstance(train_data, DataFrame), "train_data must be a fate DataFrame"
+        assert isinstance(
+            train_data, DataFrame), "train_data must be a fate DataFrame"
         if validate_data is not None:
-            assert
isinstance(validate_data, DataFrame), "validate_data must be a fate DataFrame" + assert isinstance( + validate_data, DataFrame), "validate_data must be a fate DataFrame" self.train_set = self._make_dataset(train_data) if not self.train_set.has_label(): @@ -274,8 +336,10 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No if not self.validate_set.has_label(): raise RuntimeError("validate data must have label column") self.validate_feature_num = self.validate_set.features.shape[1] - assert self.train_feature_num == self.validate_feature_num, "train and validate feature num not match: {} vs {}".format(self.train_feature_num, self.validate_feature_num) - unique_label_set = unique_label_set.union(set(self.validate_set.get_classes())) + assert self.train_feature_num == self.validate_feature_num, "train and validate feature num not match: {} vs {}".format( + self.train_feature_num, self.validate_feature_num) + unique_label_set = unique_label_set.union( + set(self.validate_set.get_classes())) self._check_labels(unique_label_set, validate_data is not None) @@ -295,21 +359,37 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No # initialize model if self.model is None: fit_intercept = self.init_param["fit_intercept"] - self.model = HomoLRModel(self.train_feature_num, label_num=len(unique_label_set), l1=self.l1, bias=fit_intercept) + self.model = HomoLRModel( + self.train_feature_num, + label_num=len(unique_label_set), + l1=self.l1, + bias=fit_intercept) # init model here - init_model(self.model, method=self.init_param["method"], fill_val=self.init_param["fill_val"]) + init_model( + self.model, + method=self.init_param["method"], + fill_val=self.init_param["fill_val"]) logger.info('model initialized') - logger.info('model parameters are {}'.format(list(self.model.parameters()))) + logger.info( + 'model parameters are {}'.format( + list( + self.model.parameters()))) else: logger.info('model is loaded, warm start training') logger.info('model structure is {}'.format(self.model)) - self.optimizer = optimizer_factory(self.model.parameters(), opt_method, optimizer_params) - self.lr_scheduler = lr_scheduler_factory(self.optimizer, self.learning_rate_param['method'], self.learning_rate_param['scheduler_params']) + self.optimizer = optimizer_factory( + self.model.parameters(), opt_method, optimizer_params) + self.lr_scheduler = lr_scheduler_factory( + self.optimizer, + self.learning_rate_param['method'], + self.learning_rate_param['scheduler_params']) if self.optimizer_state_dict is not None: optimizer_state_dict = { - "state": {k: t.tensor(v) for k, v in self.optimizer_state_dict['state'].items()}, + "state": { + k: t.tensor(v) for k, + v in self.optimizer_state_dict['state'].items()}, "param_groups": self.optimizer_state_dict['param_groups'], } self.optimizer.load_state_dict(optimizer_state_dict) @@ -317,31 +397,51 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = No # training fed_arg = FedAVGArguments() - train_arg = TrainingArguments(num_train_epochs=self.max_iter, - per_device_train_batch_size=self.batch_size, per_device_eval_batch_size=self.batch_size) - self.trainer = FedAVGCLient(ctx, model=self.model, loss_fn=loss_fn, optimizer=self.optimizer, train_set=self.train_set, - val_set=self.validate_set, training_args=train_arg, fed_args=fed_arg, data_collator=default_data_collator, scheduler=self.lr_scheduler) + train_arg = TrainingArguments( + num_train_epochs=self.max_iter, + 
per_device_train_batch_size=self.batch_size, + per_device_eval_batch_size=self.batch_size) + self.trainer = FedAVGCLient( + ctx, + model=self.model, + loss_fn=loss_fn, + optimizer=self.optimizer, + train_set=self.train_set, + val_set=self.validate_set, + training_args=train_arg, + fed_args=fed_arg, + data_collator=default_data_collator, + scheduler=self.lr_scheduler) if self.local_mode: # for debugging self.trainer.set_local_mode() self.trainer.train() logger.info('training finished') - + def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: - + if self.model is None: raise ValueError("model is not initialized") self.predict_set = self._make_dataset(predict_data) if self.trainer is None: - batch_size = len(self.predict_set) if self.batch_size == -1 else self.batch_size - train_arg = TrainingArguments(num_train_epochs=self.max_iter, per_device_eval_batch_size=batch_size) - trainer = FedAVGCLient(ctx, train_set=self.predict_set, model=self.model, training_args=train_arg, - fed_args=FedAVGArguments(), data_collator=default_data_collator) + batch_size = len( + self.predict_set) if self.batch_size == -1 else self.batch_size + train_arg = TrainingArguments( + num_train_epochs=self.max_iter, + per_device_eval_batch_size=batch_size) + trainer = FedAVGCLient( + ctx, + train_set=self.predict_set, + model=self.model, + training_args=train_arg, + fed_args=FedAVGArguments(), + data_collator=default_data_collator) trainer.set_local_mode() else: trainer = self.trainer predict_rs = trainer.predict(self.predict_set) - predict_out_df = self._make_output_df(predict_rs, self.predict_set, self.threshold) + predict_out_df = self._make_output_df( + predict_rs, self.predict_set, self.threshold) match_id_name = self.predict_set.get_match_ids().columns[0] sample_id_name = self.predict_set.get_sample_ids().columns[0] return to_fate_df(ctx, match_id_name, sample_id_name, predict_out_df) @@ -351,11 +451,19 @@ def get_model(self) -> ModelIO: if self.model is not None: param['model'] = self.model.to_dict() if self.optimizer is not None: - param['optimizer'] = str(get_torch_bytes(self.optimizer.state_dict())) - - meta = {'batch_size': self.batch_size, 'max_iter': self.max_iter, 'threshold': self.threshold, - 'optimizer_param': self.optimizer_param, 'learning_rate_param': self.learning_rate_param, 'init_param': self.init_param, 'ovr': self.ovr, - 'label_num': self.label_num} + param['optimizer'] = str( + get_torch_bytes( + self.optimizer.state_dict())) + + meta = { + 'batch_size': self.batch_size, + 'max_iter': self.max_iter, + 'threshold': self.threshold, + 'optimizer_param': self.optimizer_param, + 'learning_rate_param': self.learning_rate_param, + 'init_param': self.init_param, + 'ovr': self.ovr, + 'label_num': self.label_num} export_ = ModelIO(data=param, meta=meta) return export_ @@ -363,18 +471,18 @@ def get_model(self) -> ModelIO: def from_model(self, model: ModelIO): model = model.dict() - if not 'data' in model: + if 'data' not in model: raise ('key "data" is not found in the input model dict') - + model_param = model['data'] - if not 'model' in model_param: - raise ValueError("param dict must have key 'model' that contains the model parameter and structure info") + if 'model' not in model_param: + raise ValueError( + "param dict must have key 'model' that contains the model parameter and structure info") self.model = HomoLRModel.from_dict(model_param['model']) if self.ovr: assert len(self.model.models) == self.label_num, '' self.model.l1 = self.l1 if hasattr(model_param, 'optimizer'): - 
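The list-based dict serialization shown in to_dict()/from_dict() earlier round-trips exactly; a quick sketch, assuming HomoLRModel as defined above in this patch:

    import torch as t

    m1 = HomoLRModel(feature_num=4, label_num=2)
    m2 = HomoLRModel.from_dict(m1.to_dict())   # lists turned back into tensors
    for a, b in zip(m1.state_dict().values(), m2.state_dict().values()):
        assert t.equal(a, b)
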
self.optimizer_state_dict = recover_torch_bytes(bytes(model_param['optimizer'], 'utf-8'))
+            self.optimizer_state_dict = recover_torch_bytes(
+                bytes(model_param['optimizer'], 'utf-8'))
         self.loaded_meta = model['meta']
-
-
diff --git a/python/fate/ml/glm/homo/lr/server.py b/python/fate/ml/glm/homo/lr/server.py
index 89b4ee38ad..e143fd21fc 100644
--- a/python/fate/ml/glm/homo/lr/server.py
+++ b/python/fate/ml/glm/homo/lr/server.py
@@ -8,18 +8,22 @@
 logger = logging.getLogger(__name__)
 
+
 class HomoLRServer(HomoModule):
 
     def __init__(self) -> None:
         pass
 
-    def fit(self, ctx: Context, data: DataFrame=None) -> None:
-
+    def fit(self, ctx: Context, data: DataFrame = None) -> None:
+
         server = FedAVGServer(ctx=ctx)
         logger.info('server class init done, start fed training')
         server.train()
         logger.info('homo lr fit done')
 
-    def predict(self, ctx: Context, predict_data: DataFrame=None) -> DataFrame:
-
-        logger.info('kkip prediction stage')
\ No newline at end of file
+    def predict(
+            self,
+            ctx: Context,
+            predict_data: DataFrame = None) -> DataFrame:
+
+        logger.info('skip prediction stage')
diff --git a/python/fate/ml/glm/homo/lr/test/local_test.py b/python/fate/ml/glm/homo/lr/test/local_test.py
index 6582d5cbe2..2572f129ed 100644
--- a/python/fate/ml/glm/homo/lr/test/local_test.py
+++ b/python/fate/ml/glm/homo/lr/test/local_test.py
@@ -13,23 +13,29 @@
 logger.setLevel(logging.INFO)
 ch = logging.StreamHandler()
 ch.setLevel(logging.DEBUG)
-formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+formatter = logging.Formatter(
+    '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 ch.setFormatter(formatter)
 logger.addHandler(ch)
 
 computing = CSession()
-ctx = Context(
-    "guest",
-    computing=computing,
-    federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]),
-)
+ctx = Context("guest", computing=computing, federation=StandaloneFederation(
+    computing, "fed", ("guest", 10000), [("host", 9999)]), )
 
-df = pd.read_csv('/home/cwj/FATE/FATE-2.0-pure/FATE/examples/data/breast_homo_guest.csv')
+df = pd.read_csv(
+    '/home/cwj/FATE/FATE-2.0-pure/FATE/examples/data/breast_homo_guest.csv')
 df['sample_id'] = [i for i in range(len(df))]
 
-reader = PandasReader(sample_id_name='sample_id', match_id_name="id", label_name="y", dtype="object")
-reader_2 = PandasReader(sample_id_name='sample_id', match_id_name="id", dtype="object")
+reader = PandasReader(
+    sample_id_name='sample_id',
+    match_id_name="id",
+    label_name="y",
+    dtype="object")
+reader_2 = PandasReader(
+    sample_id_name='sample_id',
+    match_id_name="id",
+    dtype="object")
 data = reader.to_frame(ctx, df)
 
 # df = data.as_pd_df()
@@ -37,9 +43,12 @@
 ds = TableDataset(return_dict=True, to_tensor=True)
 ds.load(data)
 
-client = HomoLRClient(50, 800, optimizer_param={'method': 'adam', 'penalty': 'l1', 'aplha':0.1, 'optimizer_para': {'lr': 0.1}}, init_param={'method': 'random', 'fill_val': 1.0},
-                      learning_rate_scheduler={'method': 'linear', 'scheduler_params': {'start_factor'}}
-)
+client = HomoLRClient(
+    50, 800, optimizer_param={
+        'method': 'adam', 'penalty': 'l1', 'alpha': 0.1, 'optimizer_params': {
+            'lr': 0.1}}, init_param={
+        'method': 'random', 'fill_val': 1.0}, learning_rate_scheduler={
+        'method': 'linear', 'scheduler_params': {'start_factor': 1.0}})
 client.l2 = 0.01
 client.l1 = 0.01
 client.local_mode = True
@@ -47,4 +56,3 @@
 export_model = client.get_model()
 pred = client.predict(ctx, data)
 # pred_2 = client.predict(ctx, data_2)
-
diff --git a/python/fate/ml/nn/algo/homo/fedavg.py
b/python/fate/ml/nn/algo/homo/fedavg.py index 15ddfeddc4..debc91fcb4 100644 --- a/python/fate/ml/nn/algo/homo/fedavg.py +++ b/python/fate/ml/nn/algo/homo/fedavg.py @@ -30,39 +30,56 @@ class FedAVGArguments(FedArguments): secure_aggregate: bool Whether to use secure aggregation or not. """ - + weighted_aggregate: bool = field(default=True) secure_aggregate: bool = field(default=False) class FedAVGCLient(FedTrainerClient): - - def __init__(self, + + def __init__(self, ctx: Context, - model: Module, - training_args: TrainingArguments, fed_args: FedArguments, - train_set: Dataset, - val_set: Dataset = None, - loss_fn: Module = None, - optimizer: Optimizer = None, - scheduler: _LRScheduler = None, - callbacks: List[TrainerCallback] = [], - data_collator: Callable=None, + model: Module, + training_args: TrainingArguments, fed_args: FedArguments, + train_set: Dataset, + val_set: Dataset = None, + loss_fn: Module = None, + optimizer: Optimizer = None, + scheduler: _LRScheduler = None, + callbacks: List[TrainerCallback] = [], + data_collator: Callable = None, tokenizer: Optional[PreTrainedTokenizer] = None, - use_hf_default_behavior: bool = False, - compute_metrics: Callable = None, + use_hf_default_behavior: bool = False, + compute_metrics: Callable = None, local_mode: bool = False ): - - super().__init__(ctx, model, training_args, fed_args, train_set, val_set, loss_fn, optimizer, data_collator, scheduler, - tokenizer, callbacks, use_hf_default_behavior, - compute_metrics=compute_metrics, local_mode=local_mode) + + super().__init__( + ctx, + model, + training_args, + fed_args, + train_set, + val_set, + loss_fn, + optimizer, + data_collator, + scheduler, + tokenizer, + callbacks, + use_hf_default_behavior, + compute_metrics=compute_metrics, + local_mode=local_mode) def init_aggregator(self): sample_num = len(self.train_dataset) - aggregator = PlainTextAggregatorClient(self.ctx, aggregator_name='fed_avg', aggregate_type='weighted_mean', sample_num=sample_num) + aggregator = PlainTextAggregatorClient( + self.ctx, + aggregator_name='fed_avg', + aggregate_type='weighted_mean', + sample_num=sample_num) return aggregator - + @time_decorator('FedAVG') def on_federation( self, @@ -77,27 +94,33 @@ def on_federation( control: Optional[TrainerControl] = None, state: Optional[TrainerState] = None, **kwargs): - + aggregator.model_aggregation(model) class FedAVGServer(FedTrainerServer): - def __init__(self, + def __init__(self, ctx: Context, - training_args: TrainingArguments = None, + training_args: TrainingArguments = None, fed_args: FedArguments = None, parameter_alignment: bool = True, local_mode: bool = False ) -> None: - + super().__init__(ctx, training_args, fed_args, parameter_alignment, local_mode) def init_aggregator(self): - aggregator = PlainTextAggregatorServer(self.ctx, aggregator_name='fed_avg') + aggregator = PlainTextAggregatorServer( + self.ctx, aggregator_name='fed_avg') return aggregator - def on_federation(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments, args: TrainingArguments): + def on_federation( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments): aggregator.model_aggregation() diff --git a/python/fate/ml/nn/dataset/base.py b/python/fate/ml/nn/dataset/base.py index 7fd161d91d..a56c503afb 100644 --- a/python/fate/ml/nn/dataset/base.py +++ b/python/fate/ml/nn/dataset/base.py @@ -20,7 +20,7 @@ def __getitem__(self, item): def __len__(self): raise NotImplementedError() - + def has_label(self) -> bool: pass @@ -29,7 +29,7 
@@ def get_classes(self) -> list: def get_match_ids(self) -> pd.DataFrame: pass - + def get_sample_ids(self) -> pd.DataFrame: pass @@ -37,4 +37,4 @@ def get_sample_id_name(self) -> str: pass def get_match_id_name(self) -> str: - pass \ No newline at end of file + pass diff --git a/python/fate/ml/nn/dataset/table.py b/python/fate/ml/nn/dataset/table.py index 0414570349..52198c5259 100644 --- a/python/fate/ml/nn/dataset/table.py +++ b/python/fate/ml/nn/dataset/table.py @@ -28,10 +28,16 @@ class TableDataset(Dataset): """ def __init__( - self, label_col=None, match_id_col=None, sample_id_col=None, - feature_dtype="float", label_dtype="float", label_shape=None, flatten_label=False, - to_tensor=True, return_dict=False - ): + self, + label_col=None, + match_id_col=None, + sample_id_col=None, + feature_dtype="float", + label_dtype="float", + label_shape=None, + flatten_label=False, + to_tensor=True, + return_dict=False): super(TableDataset, self).__init__() self.features: np.ndarray = None @@ -46,13 +52,12 @@ def __init__( self.to_tensor = to_tensor self.return_dict = return_dict if label_shape is not None: - assert isinstance(label_shape, tuple) or isinstance(label_shape, list), "label shape is {}".format( - label_shape - ) + assert isinstance(label_shape, tuple) or isinstance( + label_shape, list), "label shape is {}".format(label_shape) self.label_shape = label_shape self.flatten_label = flatten_label - # sample ids, match ids + # sample ids, match ids self.sample_ids = None self.match_ids = None @@ -66,7 +71,8 @@ def check_dtype(dtype): if dtype is not None: avail = ["long", "int", "float", "double"] - assert dtype in avail, "available dtype is {}, but got {}".format(avail, dtype) + assert dtype in avail, "available dtype is {}, but got {}".format( + avail, dtype) if dtype == "long": return np.int64 if dtype == "int": @@ -110,30 +116,39 @@ def load(self, data_or_path): # automatically set id columns if self.match_id_col is not None: if self.match_id_col not in self.origin_table: - raise ValueError("match id column {} not found".format(self.match_id_col)) + raise ValueError( + "match id column {} not found".format( + self.match_id_col)) else: self.match_ids = self.origin_table[[self.match_id_col]] - self.origin_table = self.origin_table.drop(columns=[self.match_id_col]) + self.origin_table = self.origin_table.drop( + columns=[self.match_id_col]) else: match_id_col_cadidaites = ["id", "sid"] for id_col in match_id_col_cadidaites: if id_col in self.origin_table: self.match_ids = self.origin_table[[id_col]] - self.origin_table = self.origin_table.drop(columns=[id_col]) + self.origin_table = self.origin_table.drop(columns=[ + id_col]) break if self.match_ids is None: - logger.info("match id column not found, no match id will be set") - + logger.info( + "match id column not found, no match id will be set") + # generate sample ids if self.sample_id_col is not None: if self.sample_id_col not in self.origin_table: - raise ValueError("sample id column {} not found".format(self.sample_id_col)) + raise ValueError( + "sample id column {} not found".format( + self.sample_id_col)) self.sample_ids = self.origin_table[[self.sample_id_col]] - self.origin_table = self.origin_table.drop(columns=[self.sample_id_col]) + self.origin_table = self.origin_table.drop( + columns=[self.sample_id_col]) else: self.sample_ids = pd.DataFrame() self.sample_ids["sample_id"] = range(len(self.origin_table)) - logger.info("sample id column not found, generate sample id from 0 to {}".format(len(self.origin_table))) + 
logger.info("sample id column not found, generate sample id from 0 to {}".format( + len(self.origin_table))) # infer column name label = self.label_col @@ -144,23 +159,26 @@ def load(self, data_or_path): break if label is None: self.with_label = False - logger.info('found no "y"/"label"/"target" in input table, no label will be set') + logger.info( + 'found no "y"/"label"/"target" in input table, no label will be set') else: if label not in self.origin_table: - raise ValueError("label column {} not found in input table".format(label)) - + raise ValueError( + "label column {} not found in input table".format(label)) + if self.label is not None: self.label = self.origin_table[[label]].values self.origin_table = self.origin_table.drop(columns=[label]) self.features = self.origin_table.values - + elif isinstance(data_or_path, DataFrame): schema = data_or_path.schema sample_id = schema.sample_id_name match_id = schema.match_id_name label = schema.label_name if label is None: - logger.info("label column is None, not provided in the uploaded data") + logger.info( + "label column is None, not provided in the uploaded data") pd_df = data_or_path.as_pd_df() if label is None: labels = None @@ -174,7 +192,6 @@ def load(self, data_or_path): self.sample_ids = sample_ids self.match_ids = match_ids self.features = features.values - if self.label is not None: @@ -188,7 +205,7 @@ def load(self, data_or_path): if self.flatten_label: self.label = self.label.flatten() - + else: self.label = None @@ -199,25 +216,28 @@ def get_classes(self): if self.label is not None: return np.unique(self.label).tolist() else: - raise ValueError("no label found, please check if self.label is set") + raise ValueError( + "no label found, please check if self.label is set") def get_sample_ids(self) -> pd.DataFrame: return self.sample_ids def get_match_ids(self) -> pd.DataFrame: return self.match_ids - + def get_sample_id_name(self) -> str: - if self.sample_ids is not None and isinstance(self.sample_ids, pd.DataFrame): + if self.sample_ids is not None and isinstance( + self.sample_ids, pd.DataFrame): return self.sample_ids.columns[0] else: raise ValueError('Cannot get sample id name') - + def get_match_id_name(self) -> str: - if self.match_ids is not None and isinstance(self.match_ids, pd.DataFrame): + if self.match_ids is not None and isinstance( + self.match_ids, pd.DataFrame): return self.match_ids.columns[0] else: raise ValueError('Cannot get match id name') - + def has_label(self) -> bool: return self.label is not None diff --git a/python/fate/ml/nn/model_zoo/multi_model.py b/python/fate/ml/nn/model_zoo/multi_model.py index 5713918b47..ca0714f7f4 100644 --- a/python/fate/ml/nn/model_zoo/multi_model.py +++ b/python/fate/ml/nn/model_zoo/multi_model.py @@ -1,5 +1,6 @@ from torch import nn + class Multi(nn.Module): def __init__(self, feat=18, class_num=4) -> None: @@ -16,4 +17,3 @@ def forward(self, x): return self.model(x) else: return nn.Softmax(dim=-1)(self.model(x)) - \ No newline at end of file diff --git a/python/fate/ml/nn/trainer/trainer_base.py b/python/fate/ml/nn/trainer/trainer_base.py index 208d782152..1ef4d2d94a 100644 --- a/python/fate/ml/nn/trainer/trainer_base.py +++ b/python/fate/ml/nn/trainer/trainer_base.py @@ -49,20 +49,25 @@ def wrapper(*args, **kwargs): def get_ith_checkpoint(directory, i): # List all files in the directory files = os.listdir(directory) - + # Filter for checkpoint directories checkpoint_dirs = [f for f in files if f.startswith("checkpoint-")] - + # Extract the numbers from the checkpoint 
directory names - checkpoint_numbers = [int(re.search(r'\d+', dir).group()) for dir in checkpoint_dirs] - - # Pair the checkpoint directories with their numbers and sort by the numbers - sorted_checkpoints = sorted(zip(checkpoint_dirs, checkpoint_numbers), key=lambda x: x[1]) - + checkpoint_numbers = [int(re.search(r'\d+', dir).group()) + for dir in checkpoint_dirs] + + # Pair the checkpoint directories with their numbers and sort by the + # numbers + sorted_checkpoints = sorted( + zip(checkpoint_dirs, checkpoint_numbers), key=lambda x: x[1]) + if i < 0: - raise ValueError(f"checkpoint idx i must be greater than or equal to 0, got {i}") + raise ValueError( + f"checkpoint idx i must be greater than or equal to 0, got {i}") if i > len(sorted_checkpoints) - 1: - raise ValueError(f"checkpoint number is {len(sorted_checkpoints)}, but got {i}") + raise ValueError( + f"checkpoint number is {len(sorted_checkpoints)}, but got {i}") # Return the name of the ith checkpoint directory return sorted_checkpoints[i][0] @@ -82,7 +87,8 @@ class FedArguments(object): """ The argument for Fed algorithm """ - aggregate_strategy: AggregateStrategy = field(default=AggregateStrategy.EPOCH.value) + aggregate_strategy: AggregateStrategy = field( + default=AggregateStrategy.EPOCH.value) aggregate_freq: int = field(default=1) def to_dict(self): @@ -91,7 +97,8 @@ def to_dict(self): the token values by removing their value. """ # filter out fields that are defined as field(init=False) - d = dict((field.name, getattr(self, field.name)) for field in fields(self) if field.init) + d = dict((field.name, getattr(self, field.name)) + for field in fields(self) if field.init) for k, v in d.items(): if isinstance(v, Enum): @@ -105,7 +112,7 @@ def to_dict(self): @dataclass class TrainingArguments(_hf_TrainingArguments): - + # in fate-2.0, we will control the output dir when using pipeline output_dir: str = field(default='./') disable_tqdm: bool = field(default=True) @@ -115,10 +122,10 @@ class TrainingArguments(_hf_TrainingArguments): logging_dir: str = field(default=None) checkpoint_idx: int = field(default=None) # by default we use constant learning rate, the same as FATE-1.X - lr_scheduler_type: str = field(default="constant") + lr_scheduler_type: str = field(default="constant") def __post_init__(self): - + # Always use default values for hub-related attributes self.push_to_hub = False self.hub_model_id = None @@ -128,7 +135,7 @@ def __post_init__(self): self.push_to_hub_model_id = None self.push_to_hub_organization = None self.push_to_hub_token = None - + super().__post_init__() def to_dict(self): @@ -140,7 +147,9 @@ def to_dict(self): default_args = _hf_TrainingArguments(output_dir='./').to_dict() # Filter out args that are equal to their default values - set_args = {name: value for name, value in all_args.items() if value != default_args.get(name)} + set_args = { + name: value for name, + value in all_args.items() if value != default_args.get(name)} return set_args @@ -236,7 +245,7 @@ def on_step_begin( aggregator: Aggregator, fed_args: FedArguments, args: TrainingArguments, - model: Optional[nn.Module] = None, + model: Optional[nn.Module] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None, dataloader: Optional[Tuple[DataLoader]] = None, @@ -246,59 +255,68 @@ def on_step_begin( pass def on_step_end( - self, - ctx: Context, - aggregator: Aggregator, - fed_args: FedArguments, - args: TrainingArguments, - model: Optional[nn.Module] = None, - optimizer: Optional[Optimizer] = None, - 
scheduler: Optional[_LRScheduler] = None, - dataloader: Optional[Tuple[DataLoader]] = None, - control: Optional[TrainerControl] = None, - state: Optional[TrainerState] = None, - **kwargs): + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs): pass class FedCallbackInterface(object): def on_federation( - self, - ctx: Context, - aggregator: Aggregator, - fed_args: FedArguments, - args: TrainingArguments, - model: Optional[nn.Module] = None, - optimizer: Optional[Optimizer] = None, - scheduler: Optional[_LRScheduler] = None, - dataloader: Optional[Tuple[DataLoader]] = None, - control: Optional[TrainerControl] = None, - state: Optional[TrainerState] = None, - **kwargs): + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs): pass def init_aggregator(self): - raise NotImplementedError('init_aggregator() must be implemented in subclass, init aggregator here') + raise NotImplementedError( + 'init_aggregator() must be implemented in subclass, init aggregator here') # I dont like huggingface logging class LogSuppressFilter(logging.Filter): def filter(self, record): suppress_list = set( - ["\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"] - ) + ["\n\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\n\n"]) if record.getMessage() in suppress_list: return False return True - -def compute_max_aggregation(fed_args: FedArguments, max_epoch: int, max_steps: int, epochs_trained: int, steps_trained: int) -> int: - - assert max_epoch > epochs_trained and max_epoch > 0, 'max_epoch must be greater than epochs_trained: {} and greater than 0'.format(epochs_trained) - assert max_steps > steps_trained and max_steps > 0, 'max_steps must be greater than steps_trained: {} and greater than 0'.format(steps_trained) - if isinstance(fed_args.aggregate_freq, float) and fed_args.aggregate_freq < 1 and fed_args.aggregate_freq > 0: +def compute_max_aggregation( + fed_args: FedArguments, + max_epoch: int, + max_steps: int, + epochs_trained: int, + steps_trained: int) -> int: + + assert max_epoch > epochs_trained and max_epoch > 0, 'max_epoch must be greater than epochs_trained: {} and greater than 0'.format( + epochs_trained) + assert max_steps > steps_trained and max_steps > 0, 'max_steps must be greater than steps_trained: {} and greater than 0'.format( + steps_trained) + + if isinstance( + fed_args.aggregate_freq, + float) and fed_args.aggregate_freq < 1 and fed_args.aggregate_freq > 0: if fed_args.aggregate_strategy == AggregateStrategy.EPOCH.value: aggregate_freq = int(max_epoch / int(1 / fed_args.aggregate_freq)) elif fed_args.aggregate_strategy == AggregateStrategy.STEP.value: @@ -307,7 +325,8 @@ def compute_max_aggregation(fed_args: FedArguments, max_epoch: int, max_steps: i elif isinstance(fed_args.aggregate_freq, int) and fed_args.aggregate_freq > 0: aggregate_freq = fed_args.aggregate_freq else: - raise ValueError('aggregate_freq must be a positive integer or a float between 0 and 1') + raise ValueError( + 'aggregate_freq must be a positive integer or a float between 0 and 1') if fed_args.aggregate_strategy == AggregateStrategy.EPOCH.value: max_aggregation = int((max_epoch - epochs_trained) / aggregate_freq) @@ -317,12 +336,20 @@ def compute_max_aggregation(fed_args: FedArguments, max_epoch: int, max_steps: i raise ValueError('aggregate_strategy must be either "epoch" or "step"') return max_aggregation, aggregate_freq - + class AggregationChecker: - def __init__(self, fed_args, max_aggregation, aggregate_freq, max_epoch: int, max_steps: int, epochs_trained: int, steps_trained: int): - + def __init__( + self, + fed_args, + max_aggregation, + aggregate_freq, + max_epoch: int, + max_steps: int, + epochs_trained: int, + steps_trained: int): + self.fed_args = fed_args self.max_epoch = max_epoch self.max_steps = max_steps @@ -333,7 +360,8 @@ def __init__(self, fed_args, max_aggregation, aggregate_freq, max_epoch: int, ma self.max_aggregation = max_aggregation def report(self): - logger.info(f'Aggregation count: {self.aggregation_count} / {self.max_aggregation}') + logger.info( + f'Aggregation count: {self.aggregation_count} / {self.max_aggregation}') def should_aggregate(self, state: TrainerState) -> bool: @@ -349,12 +377,14 @@ def should_aggregate(self, state: TrainerState) -> bool: strategy = self.fed_args.aggregate_strategy if strategy == AggregateStrategy.EPOCH.value: - if cur_epoch > self.epochs_trained and (cur_epoch - self.epochs_trained) % self.aggregate_freq == 0: + if cur_epoch > self.epochs_trained and ( + cur_epoch - self.epochs_trained) % self.aggregate_freq == 0: self.aggregation_count += 1 self.report() return True elif strategy == AggregateStrategy.STEP.value: - if cur_step > self.steps_trained and (cur_step - self.steps_trained) 
% self.aggregate_freq == 0: + if cur_step > self.steps_trained and ( + cur_step - self.steps_trained) % self.aggregate_freq == 0: self.aggregation_count += 1 self.report() return True @@ -364,7 +394,13 @@ def should_aggregate(self, state: TrainerState) -> bool: class FedParameterAlignCallback(TrainerCallback): - def __init__(self, trainer_class, ctx: Context, training_args: TrainingArguments, fed_args: FedArguments, is_server: bool = False) -> None: + def __init__( + self, + trainer_class, + ctx: Context, + training_args: TrainingArguments, + fed_args: FedArguments, + is_server: bool = False) -> None: super().__init__() self.trainer_class = trainer_class self.ctx = ctx @@ -379,7 +415,11 @@ def __init__(self, trainer_class, ctx: Context, training_args: TrainingArguments def get_aggregation_checker(self): return self._aggregation_checker - def _client_send_parameters(self, state: TrainerState, args, train_dataloader): + def _client_send_parameters( + self, + state: TrainerState, + args, + train_dataloader): # client need to compute: epochs, max_steps, num_step_per_epoch, trained_epoch, trained_steps # and sync with server @@ -392,28 +432,32 @@ def _client_send_parameters(self, state: TrainerState, args, train_dataloader): num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) if args.max_steps > 0: max_steps = args.max_steps - num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( - args.max_steps % num_update_steps_per_epoch > 0 - ) + num_train_epochs = args.max_steps // num_update_steps_per_epoch + \ + int(args.max_steps % num_update_steps_per_epoch > 0) else: - max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + max_steps = math.ceil( + args.num_train_epochs * + num_update_steps_per_epoch) num_train_epochs = math.ceil(args.num_train_epochs) elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size max_steps = args.max_steps - # Setting a very large number of epochs so we go as many times as necessary over the iterator. + # Setting a very large number of epochs so we go as many times as + # necessary over the iterator. 
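The scheduling bookkeeping above is easiest to check with numbers. Under the "epoch" strategy, a fractional aggregate_freq is first converted into an every-k-epochs cadence; a sketch for a fresh run, using the helpers defined earlier in this file:

    fed = FedArguments(aggregate_strategy='epoch', aggregate_freq=0.5)
    max_agg, freq = compute_max_aggregation(
        fed, max_epoch=10, max_steps=1000, epochs_trained=0, steps_trained=0)
    print(max_agg, freq)   # 2 5: aggregate every 5 epochs, 2 times in total
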
num_train_epochs = sys.maxsize num_update_steps_per_epoch = max_steps # warm start related variables epochs_trained = state.global_step // num_update_steps_per_epoch if not args.ignore_data_skip: - steps_trained_in_current_epoch = state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch = state.global_step % ( + num_update_steps_per_epoch) steps_trained_in_current_epoch *= args.gradient_accumulation_steps else: steps_trained_in_current_epoch = 0 - max_aggregation, aggregate_freq = compute_max_aggregation(self.fed_args, num_train_epochs, max_steps, epochs_trained, state.global_step) + max_aggregation, aggregate_freq = compute_max_aggregation( + self.fed_args, num_train_epochs, max_steps, epochs_trained, state.global_step) logger.info('computed max_aggregation is {}'.format(max_aggregation)) # send parameters @@ -430,10 +474,18 @@ def _client_send_parameters(self, state: TrainerState, args, train_dataloader): logger.info('parameters is {}'.format(parameters)) - self.ctx.arbiter.put(self._suffix + '_' + str(self._send_count), parameters) + self.ctx.arbiter.put(self._suffix + '_' + + str(self._send_count), parameters) self._send_count += 1 self._parameters = parameters - self.trainer_class.aggregation_checker = AggregationChecker(self.fed_args, max_aggregation, aggregate_freq, num_train_epochs, max_steps, epochs_trained, state.global_step) + self.trainer_class.aggregation_checker = AggregationChecker( + self.fed_args, + max_aggregation, + aggregate_freq, + num_train_epochs, + max_steps, + epochs_trained, + state.global_step) def get_parameters(self): return self._parameters @@ -454,18 +506,20 @@ def _check_fed_strategy(self, parameters): all_cilent_strategy = [p['aggregation_strategy'] for p in parameters] logger.info('all client strategies are {}'.format(all_cilent_strategy)) strategy_flag = self._startegy_type(all_cilent_strategy[0]) - for p in all_cilent_strategy[1: ]: + for p in all_cilent_strategy[1:]: if self._startegy_type(p) != strategy_flag: - raise ValueError('fed strategy not match, all clients has to have the same strategy: by epoch(epoch) or by step(step, progress_percentage),\n \ + raise ValueError( + 'fed strategy not match, all clients has to have the same strategy: by epoch(epoch) or by step(step, progress_percentage),\n \ please check: {}'.format(all_cilent_strategy)) - + return strategy_flag def _check_federation_round(self, parameters): - + agg_round = [p['max_aggregation'] for p in parameters] if len(set(agg_round)) != 1: - raise ValueError('federation round not match, all clients has to have the same aggregation round,\n \ + raise ValueError( + 'federation round not match, all clients has to have the same aggregation round,\n \ please check: {}'.format(agg_round)) return agg_round[0] @@ -480,10 +534,16 @@ def _server_check_parameters(self): agg_round = self._check_federation_round(para) self._parameters = {'max_aggregation': agg_round} - def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): + if self.trainer_class.local_mode: - logger.info('FedParameterAlignCallback: local model, skipping federated parameter checking') + logger.info( + 'FedParameterAlignCallback: local model, skipping federated parameter checking') return else: if self.is_server: @@ -503,7 +563,6 @@ def on_log(self, args, state, control, logs=None, **kwargs): class CallbackWrapper(TrainerCallback): - def 
__init__(self, ctx: Context, wrapped_trainer: 'StdFedTrainerMixin'): self.ctx = ctx self.wrapped_trainer = wrapped_trainer @@ -517,39 +576,72 @@ def _call_wrapped(self, event_name: str, **kwargs): eval_dataloader = kwargs.pop('eval_dataloader', None) dataloaders = tuple(filter(None, (train_dataloader, eval_dataloader))) kwargs['dataloader'] = dataloaders - return event(self.ctx, self.wrapped_trainer.aggregator, self.fed_arg, **kwargs) - + return event( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + **kwargs) + class FedCallbackWrapper(CallbackWrapper): def __init__(self, ctx: Context, wrapped_trainer: 'StdFedTrainerMixin'): super().__init__(ctx, wrapped_trainer) - def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): # initialize aggregator - # doesnot call wrapper here, make sure aggregator is not called before it is initialized + # doesnot call wrapper here, make sure aggregator is not called before + # it is initialized if self.wrapped_trainer.local_mode: - logger.info('local mode, skip federation aggregator initialization, aggregator will be None') + logger.info( + 'local mode, skip federation aggregator initialization, aggregator will be None') else: self.wrapped_trainer.aggregator = self.wrapped_trainer.init_aggregator() - def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_epoch_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): if self.wrapped_trainer.local_mode: return if self.fed_arg.aggregate_strategy == AggregateStrategy.EPOCH.value: - if self.wrapped_trainer.aggregation_checker.should_aggregate(state): + if self.wrapped_trainer.aggregation_checker.should_aggregate( + state): logger.info('aggregation on epoch end') - return self._call_wrapped('on_federation', args=args, state=state, control=control, **kwargs) - - def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + return self._call_wrapped( + 'on_federation', + args=args, + state=state, + control=control, + **kwargs) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs): if self.wrapped_trainer.local_mode: return if self.fed_arg.aggregate_strategy == AggregateStrategy.STEP.value: - if self.wrapped_trainer.aggregation_checker.should_aggregate(state): + if self.wrapped_trainer.aggregation_checker.should_aggregate( + state): logger.info('aggregation on step end') - return self._call_wrapped('on_federation', args=args, state=state, control=control, **kwargs) + return self._call_wrapped( + 'on_federation', + args=args, + state=state, + control=control, + **kwargs) + - class ShortcutCallbackWrapper(CallbackWrapper): def __init__(self, ctx: Context, wrapped_trainer: 'StdFedTrainerMixin'): @@ -645,7 +737,7 @@ def on_step_end( state=state, control=control, **kwargs) - + logger.addFilter(LogSuppressFilter()) @@ -658,7 +750,7 @@ def on_step_end( class StdFedTrainerMixin(ShortcutCallBackInterFace, FedCallbackInterface): def __init__(self, - ctx: Context, + ctx: Context, model: nn.Module, training_args: TrainingArguments, fed_args: FedArguments, @@ -672,13 +764,13 @@ def __init__(self, use_hf_default_behavior: bool = False, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, local_mode: bool = False, - 
parameter_alignment = True + parameter_alignment=True ): - + assert isinstance( callbacks, list), 'callback must be a list containing Callback objects, but got {}'.format( callbacks) - + self.ctx: Context = ctx self.local_mode = local_mode self.parameter_alignment = parameter_alignment @@ -694,19 +786,20 @@ def __init__(self, # for callback class to check if aggregation is needed self.aggregation_checker: AggregationChecker = None - + def _compute_metrics_warp_func(self, *args, **kwargs): - + if self._user_compute_metric_func is None: return {} else: eval_result = self._user_compute_metric_func(*args, **kwargs) # Do some FATEBoard Callback here return eval_result - + def _handle_callback(self, callback_handler, new_callbacks): - # remove default logger.infoer callback, need to use our logging strategy + # remove default printer callback, need to use our logging + # strategy new_callback_list = [] for i in callback_handler.callbacks: # if not isinstance(i, PrinterCallback): @@ -717,8 +810,8 @@ def _handle_callback(self, callback_handler, new_callbacks): def _add_fate_callback(self, callback_handler): # the callback handler is Trainer.callback_handler - # call order: - # fed callback aggregator init(once), parameter check(once), + # call order: + # fed callback aggregator init(once), parameter check(once), # on federation of fedcallback # callbacks of shortcutcallback new_callback_list = [] @@ -731,15 +824,19 @@ def _add_fate_callback(self, callback_handler): callback_handler.callbacks = new_callback_list callback_handler.callbacks.append(FedCallbackWrapper(self.ctx, self)) if self.parameter_alignment: - callback_handler.callbacks.append(FedParameterAlignCallback(self, - self.ctx, - fed_args=self._fed_args, - training_args=self._args, - is_server=False)) + callback_handler.callbacks.append( + FedParameterAlignCallback( + self, + self.ctx, + fed_args=self._fed_args, + training_args=self._args, + is_server=False)) else: - logger.warning('Parameter alignment is disabled, this may cause fed-training failure') - callback_handler.callbacks.append(ShortcutCallbackWrapper(self.ctx, self)) - + logger.warning( + 'Parameter alignment is disabled, this may cause fed-training failure') + callback_handler.callbacks.append( + ShortcutCallbackWrapper(self.ctx, self)) + def _remove_fed_callback(self, callback_class): self.callback_handler.callbacks = [ c for c in self.callback_handler.callbacks if not isinstance( @@ -756,7 +853,7 @@ def set_fed_mode(self): @property def aggregator(self): return self._aggregator - + @aggregator.setter def aggregator(self, value): self._aggregator = value @@ -766,6 +863,7 @@ def aggregator(self, value): Base Classes of Client/Server Trainer """ + class FedTrainerClient(Trainer, StdFedTrainerMixin): """ @@ -792,30 +890,29 @@ def __init__(self, use_hf_default_behavior: bool = False, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, local_mode: bool = False, - parameter_alignment = True + parameter_alignment=True ): - + # in case you forget to set evaluation_strategy if val_set is not None and training_args.evaluation_strategy == 'no': training_args.evaluation_strategy = 'epoch' - - StdFedTrainerMixin.__init__(self, - ctx=ctx, - model=model, - loss_fn=loss_fn, - optimizer=optimizer, - training_args=training_args, - fed_args=fed_args, - train_set=train_set, - val_set=val_set, - scheduler=scheduler, - callbacks=callbacks, - use_hf_default_behavior=use_hf_default_behavior, - compute_metrics=compute_metrics, - local_mode=local_mode, - 
parameter_alignment=parameter_alignment - ) + StdFedTrainerMixin.__init__( + self, + ctx=ctx, + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + training_args=training_args, + fed_args=fed_args, + train_set=train_set, + val_set=val_set, + scheduler=scheduler, + callbacks=callbacks, + use_hf_default_behavior=use_hf_default_behavior, + compute_metrics=compute_metrics, + local_mode=local_mode, + parameter_alignment=parameter_alignment) if data_collator is None: data_collator = _utils.collate.default_collate @@ -824,19 +921,21 @@ def __init__(self, if self._args.checkpoint_idx is not None: checkpoint_path = self._args.resume_from_checkpoint if checkpoint_path is not None and os.path.exists(checkpoint_path): - checkpoint_folder = get_ith_checkpoint(checkpoint_path, self._args.checkpoint_idx) - self._args.resume_from_checkpoint = os.path.join(checkpoint_path, checkpoint_folder) + checkpoint_folder = get_ith_checkpoint( + checkpoint_path, self._args.checkpoint_idx) + self._args.resume_from_checkpoint = os.path.join( + checkpoint_path, checkpoint_folder) Trainer.__init__(self, - model=model, - args=self._args, - train_dataset=train_set, - eval_dataset=val_set, - data_collator=data_collator, - optimizers=[optimizer, scheduler], - tokenizer=tokenizer, - compute_metrics=self._compute_metrics_warp_func - ) + model=model, + args=self._args, + train_dataset=train_set, + eval_dataset=val_set, + data_collator=data_collator, + optimizers=[optimizer, scheduler], + tokenizer=tokenizer, + compute_metrics=self._compute_metrics_warp_func + ) self._add_fate_callback(self.callback_handler) @@ -849,14 +948,17 @@ def compute_loss(self, model, inputs, **kwargs): return super().compute_loss(model, inputs, **kwargs) else: # (features, labels), this format is used in FATE-1.x - if isinstance(inputs, tuple) or isinstance(inputs, list) and len(inputs) == 2: + if isinstance( + inputs, (tuple, list)) and len(inputs) == 2: feats, labels = inputs output = model(feats) loss = self.loss_func(output, labels) return loss else: return super().compute_loss(model, inputs, **kwargs) - def prediction_step(self, model: nn.Module, @@ -865,13 +967,17 @@ def prediction_step(self, Any]], prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None): - + if self._use_hf_default_behavior: return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) else: # (features, labels), this format is used in FATE-1.x # now the model is in eval status - if isinstance(inputs, tuple) or isinstance(inputs, list) and len(inputs) == 2: + if isinstance( + inputs, (tuple, list)) and len(inputs) == 2: with torch.no_grad(): feats, labels = inputs logits = model(feats) @@ -882,25 +988,27 @@ def prediction_step(self, class FedTrainerServer(object): - def __init__(self, + def __init__(self, ctx: Context, - training_args: TrainingArguments = None, + training_args: TrainingArguments = None, fed_args: FedArguments = None, parameter_alignment: bool = True, local_mode: bool = False ) -> None: - + self.ctx = ctx self.parameter_alignment = parameter_alignment self.local_mode = local_mode self._args = training_args self._fed_args = fed_args self._max_steps = None - self._parameter_check_callback = FedParameterAlignCallback(self, self.ctx, None, None, is_server=True) + self._parameter_check_callback = FedParameterAlignCallback( + self, self.ctx, None, None, is_server=True) self._max_aggregation = None def set_fed_context(self, ctx: Context): - assert isinstance(ctx, Context), 'ctx
must be a Context object, but got {}'.format(ctx) + assert isinstance( + ctx, Context), 'ctx must be a Context object, but got {}'.format(ctx) self.ctx = ctx def set_local_mode(self): @@ -913,43 +1021,82 @@ def set_fed_mode(self): def init_aggregator(self): return None - - def on_train_end(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments, args: TrainingArguments): + + def on_train_end( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments): pass - def on_train_begin(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments, args: TrainingArguments): + def on_train_begin( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments): pass - def on_init_end(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments, args: TrainingArguments): + def on_init_end( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments): pass - def on_federation(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments, args: TrainingArguments): + def on_federation( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments): pass def train(self): if self.local_mode: - logger.info('Local model is set, skip initializing fed setting & aggregator') + logger.info( + 'Local mode is set, skip initializing fed setting & aggregator') return self.aggregator: Aggregator = self.init_aggregator() logger.info('Initialized aggregator Done: {}'.format(self.aggregator)) if self.parameter_alignment: - self._parameter_check_callback.on_train_begin(None, None, None) # only get parameters from clients and align + self._parameter_check_callback.on_train_begin( + None, None, None) # only get parameters from clients and align parameters = self._parameter_check_callback.get_parameters() self._max_aggregation = parameters['max_aggregation'] logger.info('checked parameters are {}'.format(parameters)) else: - logger.warn('If you choose not to use parameter alignment, please make sure that the sever aggregation round matches clients\'') - self._max_aggregation, _ = compute_max_aggregation(self._fed_args, self._args.num_train_epochs, self._args.max_steps, 0, 0) - - self.on_init_end(self.ctx, aggregator=self.aggregator, args=self._args, fed_args=self._fed_args) - self.on_train_begin(self.ctx, aggregator=self.aggregator, args=self._args, fed_args=self._fed_args) + logger.warning( + 'If you choose not to use parameter alignment, please make sure that the server aggregation round matches the clients\'') + self._max_aggregation, _ = compute_max_aggregation( + self._fed_args, self._args.num_train_epochs, self._args.max_steps, 0, 0) + + self.on_init_end( + self.ctx, + aggregator=self.aggregator, + args=self._args, + fed_args=self._fed_args) + self.on_train_begin( + self.ctx, + aggregator=self.aggregator, + args=self._args, + fed_args=self._fed_args) for i in range(self._max_aggregation): - self.on_federation(self.ctx, aggregator=self.aggregator, args=self._args, fed_args=self._fed_args) - self.on_train_end(self.ctx, aggregator=self.aggregator, args=self._args, fed_args=self._fed_args) + self.on_federation( + self.ctx, + aggregator=self.aggregator, + args=self._args, + fed_args=self._fed_args) + self.on_train_end( + self.ctx, + aggregator=self.aggregator, + args=self._args, + fed_args=self._fed_args) def predict(self): # server does not need to predict pass - diff --git a/python/fate/ml/utils/_convergence.py
b/python/fate/ml/utils/_convergence.py index f3557e53b2..153d529570 100644 --- a/python/fate/ml/utils/_convergence.py +++ b/python/fate/ml/utils/_convergence.py @@ -39,7 +39,9 @@ def __init__(self, eps): self.pre_loss = None def is_converge(self, loss): - logger.debug("In diff converge function, pre_loss: {}, current_loss: {}".format(self.pre_loss, loss)) + logger.debug( + "In diff converge function, pre_loss: {}, current_loss: {}".format( + self.pre_loss, loss)) converge_flag = False if self.pre_loss is None: @@ -99,4 +101,5 @@ def converge_func_factory(early_stop, tol): elif early_stop == 'abs': return _AbsConverge(tol) else: - raise NotImplementedError("Converge Function method cannot be recognized: {}".format(early_stop)) + raise NotImplementedError( + "Converge Function method cannot be recognized: {}".format(early_stop)) diff --git a/python/fate/ml/utils/_model_param.py b/python/fate/ml/utils/_model_param.py index 2d2840a417..07016973ff 100644 --- a/python/fate/ml/utils/_model_param.py +++ b/python/fate/ml/utils/_model_param.py @@ -27,7 +27,9 @@ def initialize_param(coef_len, **kwargs): elif method == 'ones': return torch.ones((param_len, 1), requires_grad=True) elif method == 'consts': - return torch.full((param_len, 1), float(kwargs["fill_val"]), requires_grad=True) + return torch.full( + (param_len, 1), float( + kwargs["fill_val"]), requires_grad=True) elif method == 'random': return torch.randn((param_len, 1), requires_grad=True) else: diff --git a/python/fate/ml/utils/_optimizer.py b/python/fate/ml/utils/_optimizer.py index c586443523..8db4c132ed 100644 --- a/python/fate/ml/utils/_optimizer.py +++ b/python/fate/ml/utils/_optimizer.py @@ -29,7 +29,8 @@ def __init__(self, method=None, lr_params=None, iters=0): self.lr_scheduler = None def init_scheduler(self, optimizer): - self.lr_scheduler = lr_scheduler_factory(optimizer, self.method, self.lr_params) + self.lr_scheduler = lr_scheduler_factory( + optimizer, self.method, self.lr_params) def step(self): self.lr_scheduler.step() @@ -55,7 +56,13 @@ def get_last_lr(self): class Optimizer(object): - def __init__(self, method=None, penalty=None, alpha=None, optim_param: dict = None, iters: int = 0): + def __init__( + self, + method=None, + penalty=None, + alpha=None, + optim_param: dict = None, + iters: int = 0): self.method = method self.optim_param = optim_param self.iters = iters @@ -67,14 +74,18 @@ def __init__(self, method=None, penalty=None, alpha=None, optim_param: dict = No self.prev_model_parameter = None self.optimizer = None - def init_optimizer(self, model_parameter_length=None, model_parameter=None, dtype=torch.float32): + def init_optimizer( + self, + model_parameter_length=None, + model_parameter=None, + dtype=torch.float32): # @todo: allow group in future if model_parameter_length is not None: - model_parameter = torch.nn.parameter.Parameter(torch.zeros((model_parameter_length, 1), - requires_grad=True, - dtype=dtype)) + model_parameter = torch.nn.parameter.Parameter(torch.zeros( + (model_parameter_length, 1), requires_grad=True, dtype=dtype)) self.model_parameter = model_parameter - self.optimizer = optimizer_factory([model_parameter], self.method, self.optim_param) + self.optimizer = optimizer_factory( + [model_parameter], self.method, self.optim_param) # for regularization # self.alpha = self.optimizer.state_dict()['param_groups'][0]['alpha'] @@ -92,7 +103,7 @@ def step(self, gradient): def get_delta_gradients(self): # logger.info(f"gradient: {self.model_parameter.grad}, prev model parameter: 
{self.prev_model_parameter}," - # f"delta grad: {self.model_parameter - self.prev_model_parameter}") + # f"delta grad: {self.model_parameter - self.prev_model_parameter}") if self.prev_model_parameter is not None: return self.model_parameter.data - self.prev_model_parameter else: @@ -124,7 +135,10 @@ def load_state_dict(self, dict): self.alpha = dict["alpha"] self.method = dict["method"] self.optim_param = dict["optim_param"] - self.init_optimizer(model_parameter=torch.nn.parameter.Parameter(torch.tensor(dict["model_parameter"]))) + self.init_optimizer( + model_parameter=torch.nn.parameter.Parameter( + torch.tensor( + dict["model_parameter"]))) state_dict = dict["optimizer"] state_all = state_dict['state'].get(0, {}) for k, v in state_all.items(): @@ -143,8 +157,11 @@ def _l1_updator(self, model_weights, gradient, fit_intercept, lr): gradient_without_intercept = gradient coef_ = model_weights - new_weights = torch.sign(coef_ - gradient_without_intercept) * torch.maximum(torch.tensor([0]), torch.abs( - coef_ - gradient_without_intercept) - self.shrinkage_val(lr)) + new_weights = torch.sign( + coef_ - gradient_without_intercept) * torch.maximum( + torch.tensor( + [0]), torch.abs( + coef_ - gradient_without_intercept) - self.shrinkage_val(lr)) if fit_intercept: new_weights = torch.concat((new_weights, model_weights.intercept_)) @@ -155,7 +172,8 @@ def _l1_updator(self, model_weights, gradient, fit_intercept, lr): def add_regular_to_grad(self, grad, model_weights, fit_intercept=False): if self.l2_penalty: if fit_intercept: - weights_sum = torch.concat((model_weights[:-1], torch.tensor([[0]]))) + weights_sum = torch.concat( + (model_weights[:-1], torch.tensor([[0]]))) logger.info(f"grad: {grad}, weights sum: {weights_sum}") new_grad = grad + self.alpha * weights_sum else: @@ -165,9 +183,16 @@ def add_regular_to_grad(self, grad, model_weights, fit_intercept=False): return new_grad - def regularization_update(self, model_weights, grad, fit_intercept, lr, prev_round_weights=None): + def regularization_update( + self, + model_weights, + grad, + fit_intercept, + lr, + prev_round_weights=None): if self.l1_penalty: - model_weights = self._l1_updator(model_weights, grad, fit_intercept, lr) + model_weights = self._l1_updator( + model_weights, grad, fit_intercept, lr) else: model_weights = model_weights - grad """elif self.l2_penalty: @@ -198,7 +223,8 @@ def __l1_loss_norm(self, model_weights): return loss_norm def __l2_loss_norm(self, model_weights): - loss_norm = 0.5 * self.alpha * torch.matmul(model_weights.T, model_weights) + loss_norm = 0.5 * self.alpha * \ + torch.matmul(model_weights.T, model_weights) return loss_norm """def __add_proximal(self, model_weights, prev_round_weights): @@ -245,16 +271,23 @@ def loss_norm(self, model_weights, prev_round_weights=None): raise_overflow_error=delta_s.raise_overflow_error) """ - def update_weights(self, model_weights, grad, fit_intercept, lr, prev_round_weights=None, - has_applied=True): - + def update_weights( + self, + model_weights, + grad, + fit_intercept, + lr, + prev_round_weights=None, + has_applied=True): """if not has_applied: grad = self.add_regular_to_grad(grad, model_weights) delta_grad = self.apply_gradients(grad) else:""" - logger.info(f"before update, model weights: {model_weights}, delta_grad: {grad}") + logger.info( + f"before update, model weights: {model_weights}, delta_grad: {grad}") delta_grad = grad - model_weights = self.regularization_update(model_weights, delta_grad, fit_intercept, lr, prev_round_weights) + model_weights = 
self.regularization_update( + model_weights, delta_grad, fit_intercept, lr, prev_round_weights) logger.info(f"after update, model weights: {model_weights}") return model_weights @@ -304,17 +337,20 @@ def optimizer_factory(model_parameter, optimizer_type, optim_params): elif optimizer_type == 'sgd': return torch.optim.SGD(model_parameter, **optimizer_params) else: - raise NotImplementedError("Optimize method cannot be recognized: {}".format(optimizer_type)) + raise NotImplementedError( + "Optimize method cannot be recognized: {}".format(optimizer_type)) def lr_scheduler_factory(optimizer, method, scheduler_param): scheduler_method = method if scheduler_method == 'constant': - return torch.optim.lr_scheduler.ConstantLR(optimizer, **scheduler_param) + return torch.optim.lr_scheduler.ConstantLR( + optimizer, **scheduler_param) elif scheduler_method == 'step': return torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_param) elif scheduler_method == 'linear': return torch.optim.lr_scheduler.LinearLR(optimizer, **scheduler_param) else: - raise NotImplementedError(f"Learning rate method cannot be recognized: {scheduler_method}") + raise NotImplementedError( + f"Learning rate method cannot be recognized: {scheduler_method}") diff --git a/python/fate/ml/utils/model_io.py b/python/fate/ml/utils/model_io.py index 406037009b..6d26291cb8 100644 --- a/python/fate/ml/utils/model_io.py +++ b/python/fate/ml/utils/model_io.py @@ -23,6 +23,6 @@ def from_dict(cls, d: dict): else: meta = None return cls(data, meta) - + def __repr__(self) -> str: - return f"{self.__class__.__name__}(data={self.data}, meta={self.meta})" \ No newline at end of file + return f"{self.__class__.__name__}(data={self.data}, meta={self.meta})" diff --git a/python/fate/ml/utils/model_serdes.py b/python/fate/ml/utils/model_serdes.py index 46a5b8413e..a7f75b14a3 100644 --- a/python/fate/ml/utils/model_serdes.py +++ b/python/fate/ml/utils/model_serdes.py @@ -23,7 +23,8 @@ def serialize_models(models): for model_name, buffer_object in models.items(): serialized_string = buffer_object.SerializeToString() pb_name = type(buffer_object).__name__ - json_format_dict = json_format.MessageToDict(buffer_object, including_default_value_fields=True) + json_format_dict = json_format.MessageToDict( + buffer_object, including_default_value_fields=True) serialized_models[model_name] = ( pb_name, diff --git a/python/fate/ml/utils/predict_tools.py b/python/fate/ml/utils/predict_tools.py index 39b199a73d..6c4d48e5b8 100644 --- a/python/fate/ml/utils/predict_tools.py +++ b/python/fate/ml/utils/predict_tools.py @@ -25,7 +25,7 @@ def predict_detail_dict_to_str(result_dict): return "\"" + json.dumps(result_dict).replace("\"", "\'") + "\"" -def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id:pd.DataFrame): +def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id: pd.DataFrame): df = pd.concat([df, match_id, sample_id], axis=1) return df @@ -33,16 +33,25 @@ def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id:pd.DataFrame): def to_fate_df(ctx, sample_id_name, match_id_name, result_df): if LABEL in result_df: - reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, label_name=LABEL, dtype="object") + reader = PandasReader( + sample_id_name=sample_id_name, + match_id_name=match_id_name, + label_name=LABEL, + dtype="object") else: - reader = PandasReader(sample_id_name=sample_id_name, match_id_name=match_id_name, dtype="object") + reader = PandasReader( + sample_id_name=sample_id_name, + 
match_id_name=match_id_name, + dtype="object") data = reader.to_frame(ctx, result_df) return data -def compute_predict_details(dataframe: Union[pd.DataFrame, DataFrame], task_type, classes: list = None, threshold=0.5): - - assert task_type in [BINARY, MULTI, REGRESSION, OTHER], 'task_type must be one of {} as a std task, but got {}'.format([BINARY, MULTI, REGRESSION, OTHER], task_type) +def compute_predict_details( + dataframe: Union[pd.DataFrame, DataFrame], task_type, classes: list = None, threshold=0.5): + + assert task_type in [BINARY, MULTI, REGRESSION, OTHER], 'task_type must be one of {} as a std task, but got {}'.format( + [BINARY, MULTI, REGRESSION, OTHER], task_type) if isinstance(dataframe, DataFrame): df = dataframe.as_pd_df() else: @@ -51,26 +60,36 @@ def compute_predict_details(dataframe: Union[pd.DataFrame, DataFrame], task_type pred = df[PREDICT_SCORE].values if PREDICT_SCORE in df else None if pred is None: raise ValueError('pred score is not found in input dataframe') - + if task_type == BINARY and task_type == MULTI and classes is None: raise ValueError('task_type is binary or multi, but classes is None') - + if task_type == BINARY: if len(classes) == 2: predict_score = np.array(pred) predict_result = (predict_score > threshold).astype(int) - predict_details = [{classes[0]: 1 - float(predict_score[i]), classes[1]: float(predict_score[i])} for i in range(len(predict_score))] + predict_details = [ + { + classes[0]: 1 - + float( + predict_score[i]), + classes[1]: float( + predict_score[i])} for i in range( + len(predict_score))] else: - raise ValueError('task_type is binary, but classes length is not 2: {}'.format(classes)) + raise ValueError( + 'task_type is binary, but classes length is not 2: {}'.format(classes)) - elif task_type == MULTI: + elif task_type == MULTI: if len(classes) > 2: predict_score = np.array([max(i) for i in pred]) predict_result = np.array([np.argmax(i) for i in pred]) - predict_details = [predict_detail_dict_to_str({classes[j]: float(pred[i][j]) for j in range(len(classes))}) for i in range(len(pred))] + predict_details = [predict_detail_dict_to_str({classes[j]: float( + pred[i][j]) for j in range(len(classes))}) for i in range(len(pred))] else: - raise ValueError('task_type is multi, but classes length is not greater than 2: {}'.format(classes)) - + raise ValueError( + 'task_type is multi, but classes length is not greater than 2: {}'.format(classes)) + elif task_type == REGRESSION: # regression task predict_score = np.array(pred) @@ -80,12 +99,17 @@ def compute_predict_details(dataframe: Union[pd.DataFrame, DataFrame], task_type df[PREDICT_RESULT] = predict_result df[PREDICT_DETAIL] = predict_details if task_type == MULTI: - df[PREDICT_SCORE] = predict_score + df[PREDICT_SCORE] = predict_score return df -def std_output_df(task_type, pred: np.array, label: np.array=None, threshold=0.5, classes: list = None): +def std_output_df( + task_type, + pred: np.array, + label: np.array = None, + threshold=0.5, + classes: list = None): df = pd.DataFrame() if len(pred.shape) == 1: @@ -96,8 +120,10 @@ def std_output_df(task_type, pred: np.array, label: np.array=None, threshold=0.5 else: df[PREDICT_SCORE] = np.array(pred).tolist() else: - raise ValueError('This is not a FATE std task, pred scores shape are {}'.format(pred.shape)) - + raise ValueError( + 'This is not a FATE std task, pred scores shape are {}'.format( + pred.shape)) + if label is not None: if len(label.shape) == 1: label = label.flatten() @@ -106,7 +132,7 @@ def std_output_df(task_type, pred: 
np.array, label: np.array=None, threshold=0.5 else: label = label.tolist() df[LABEL] = label - + df = compute_predict_details(df, task_type, classes, threshold) - return df \ No newline at end of file + return df From 97b353f841e444c0bc6a4542b7a6c369ca435898 Mon Sep 17 00:00:00 2001 From: cwj Date: Fri, 14 Jul 2023 10:45:41 +0800 Subject: [PATCH 38/61] Update homo-lr batch size setting Signed-off-by: cwj --- python/fate/components/components/homo_lr.py | 6 +++--- python/fate/ml/glm/homo/lr/client.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/fate/components/components/homo_lr.py b/python/fate/components/components/homo_lr.py index 96b4c00ef6..60a3b6a43c 100644 --- a/python/fate/components/components/homo_lr.py +++ b/python/fate/components/components/homo_lr.py @@ -45,9 +45,9 @@ def train( "refer to torch.optim.lr_scheduler"), epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), - batch_size: cpn.parameter(type=params.conint(ge=-1), default=100, - desc="batch size, " - "value less or equals to 0 means full batch"), + batch_size: cpn.parameter(type=params.conint(gt=0), default=None, + desc="batch size, int > 0, " + "if None means full batch"), optimizer: cpn.parameter(type=params.optimizer_param(), default=params.OptimizerParam(method="sgd", penalty='l2', alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0})), diff --git a/python/fate/ml/glm/homo/lr/client.py b/python/fate/ml/glm/homo/lr/client.py index 12b061aa0f..8d50f3ad52 100644 --- a/python/fate/ml/glm/homo/lr/client.py +++ b/python/fate/ml/glm/homo/lr/client.py @@ -195,7 +195,7 @@ class HomoLRClient(HomoModule): def __init__( self, epochs: int = 5, - batch_size: int = 32, + batch_size: int = None, optimizer_param={ 'method': 'sgd', 'optimizer_params': { @@ -264,9 +264,9 @@ def __init__( # check params assert self.max_iter > 0 and isinstance( self.max_iter, int), "max_iter must be int greater than 0" - if self.batch_size != -1: + if self.batch_size is not None: assert self.batch_size > 0 and isinstance( - self.batch_size, int), "batch_size must be int greater than 0 or -1" + self.batch_size, int), "batch_size must be int greater than 0 or None" assert self.threshold > 0 and self.threshold < 1, "threshold must be float between 0 and 1" def _make_dataset(self, data) -> TableDataset: @@ -343,7 +343,7 @@ def fit(self, ctx: Context, train_data: DataFrame, self._check_labels(unique_label_set, validate_data is not None) - if self.batch_size == -1: + if self.batch_size is None: self.batch_size = len(self.train_set) # prepare loss function @@ -425,7 +425,7 @@ def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: self.predict_set = self._make_dataset(predict_data) if self.trainer is None: batch_size = len( - self.predict_set) if self.batch_size == -1 else self.batch_size + self.predict_set) if self.batch_size is None else self.batch_size train_arg = TrainingArguments( num_train_epochs=self.max_iter, per_device_eval_batch_size=batch_size) From 1c83df053c8090f773366fa8db614a056f517947 Mon Sep 17 00:00:00 2001 From: weiwee Date: Fri, 14 Jul 2023 14:22:06 +0800 Subject: [PATCH 39/61] add hist mock Signed-off-by: weiwee --- python/fate/arch/tensor/inside/__init__.py | 1 + python/fate/arch/tensor/inside/_op_hist.py | 45 ++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 python/fate/arch/tensor/inside/_op_hist.py diff --git a/python/fate/arch/tensor/inside/__init__.py b/python/fate/arch/tensor/inside/__init__.py index
fdbbfc3879..70d990614c 100644 --- a/python/fate/arch/tensor/inside/__init__.py +++ b/python/fate/arch/tensor/inside/__init__.py @@ -1 +1,2 @@ +from ._op_hist import Hist from ._op_quantile import GKSummary diff --git a/python/fate/arch/tensor/inside/_op_hist.py b/python/fate/arch/tensor/inside/_op_hist.py new file mode 100644 index 0000000000..a2ddb7823f --- /dev/null +++ b/python/fate/arch/tensor/inside/_op_hist.py @@ -0,0 +1,45 @@ +class Hist: + def __init__(self): + self.data = {} + def update(self, features, labels): + shape_x, shape_y = features.shape + for i in range(shape_x): + for j in range(shape_y): + v = features[i, j] + if j not in self.data: + self.data[j] = {} + if v not in self.data[j]: + self.data[j][v] = labels[i] + else: + self.data[j][v] += labels[i] + def merge(self, hist): + for k in hist.data: + if k not in self.data: + self.data[k] = hist.data[k] + else: + for kk in hist.data[k]: + if kk not in self.data[k]: + self.data[k][kk] = hist.data[k][kk] + else: + self.data[k][kk] += hist.data[k][kk] + return self + + def cumsum(self): + for k in self.data: + s = 0 + for kk in sorted(self.data[k].keys()): + s += self.data[k][kk] + self.data[k][kk] = s + return self + + +if __name__ == "__main__": + import numpy as np + + hist = Hist() + features = np.array([[1, 0], [0, 1], [2, 1], [2, 0]]) + labels = np.array([0, 1, 0, 0]) + hist.update(features, labels) + print(hist.data) From 1dff90235e4e013c9994ac7eff8692931754bfc7 Mon Sep 17 00:00:00 2001 From: weiwee Date: Fri, 14 Jul 2023 16:28:53 +0800 Subject: [PATCH 40/61] add process watch thread to clean up processes created by standalone Signed-off-by: weiwee --- python/fate/arch/_standalone.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/python/fate/arch/_standalone.py b/python/fate/arch/_standalone.py index cbbb6e249e..59c2d0998f 100644 --- a/python/fate/arch/_standalone.py +++ b/python/fate/arch/_standalone.py @@ -20,6 +20,8 @@ import os import pickle as c_pickle import shutil +import signal +import threading import time import uuid from collections.abc import Iterable @@ -60,6 +62,31 @@ LOGGER.debug(f"env STANDALONE_DATA_PATH is not set, using {_data_dir} as data dir") +def _watch_thread_react_to_parent_die(ppid): + """ + watch the parent process and kill self once the parent process is dead. + the trick is to call os.kill(ppid, 0) periodically to check whether the parent process is still alive; + when the check raises OSError, self is killed with SIGTERM + + Note: this trick is modified from the answer by aaron: https://stackoverflow.com/a/71369760/14697733 + Args: + ppid: parent process id + + """ + pid = os.getpid() + + def f(): + while True: + try: + os.kill(ppid, 0) + except OSError: + os.kill(pid, signal.SIGTERM) + time.sleep(1) + + thread = threading.Thread(target=f, daemon=True) + thread.start() + + # noinspection PyPep8Naming class Table(object): def __init__( @@ -359,7 +386,9 @@ def delete(self, k): class Session(object): def __init__(self, session_id, max_workers=None): self.session_id = session_id - self._pool = Executor(max_workers=max_workers) + self._pool = Executor( + max_workers=max_workers, initializer=_watch_thread_react_to_parent_die, initargs=(os.getpid(),) + ) def __getstate__(self): # session won't be pickled From a879fdc2fd7c31fb6f26062e75c0e29678e2ae48 Mon Sep 17 00:00:00 2001 From: cwj Date: Fri, 14 Jul 2023 16:47:08 +0800 Subject: [PATCH 41/61] Update predict tools & fix nn bug Signed-off-by: cwj --- .../fate/components/components/evaluation.py | 4 +-
.../components/components/nn/nn_runner.py | 51 ++++----- .../components/nn/runner/default_runner.py | 6 +- .../components/nn/runner/my_runner.py | 74 ------------- .../components/nn/test/test_default_runner.py | 4 +- python/fate/ml/glm/homo/lr/client.py | 24 +++-- python/fate/ml/nn/dataset/table.py | 8 +- python/fate/ml/utils/predict_tools.py | 102 +++++++++--------- .../fate/ml/utils/test/test_predict_format.py | 49 +++++++++ 9 files changed, 147 insertions(+), 175 deletions(-) delete mode 100644 python/fate/components/components/nn/runner/my_runner.py create mode 100644 python/fate/ml/utils/test/test_predict_format.py diff --git a/python/fate/components/components/evaluation.py b/python/fate/components/components/evaluation.py index 27fc49ea98..252e41beb7 100644 --- a/python/fate/components/components/evaluation.py +++ b/python/fate/components/components/evaluation.py @@ -51,8 +51,8 @@ def evaluation( metrics: cpn.parameter(type=list, default=None, optional=True), predict_column_name: cpn.parameter(type=str, default=None, optional=True, desc="predict data column name, if None(default), will use \ - 'predict_score' when use binary and regression default setting, \ - and use 'predict_result' on multi classification default setting"), + 'predict_score' in the input dataframe when the task type is binary or regression, \ + and use 'predict_result' when the task type is multi"), label_column_name: cpn.parameter(type=str, default=None, optional=True, desc="label data column name, if None(default), \ will use 'label' in the input dataframe") ): diff --git a/python/fate/components/components/nn/nn_runner.py b/python/fate/components/components/nn/nn_runner.py index 9da71f4ed7..65be3e007a 100644 --- a/python/fate/components/components/nn/nn_runner.py +++ b/python/fate/components/components/nn/nn_runner.py @@ -8,10 +8,10 @@ from transformers.trainer_utils import PredictionOutput import numpy as np from fate.arch.dataframe._dataframe import DataFrame -from fate.arch.dataframe.manager.schema_manager import Schema from fate.components.components.utils import consts import logging -from fate.ml.utils.predict_tools import to_fate_df, std_output_df, add_ids +from fate.ml.utils.predict_tools import to_fate_df, array_to_predict_df +from fate.ml.utils.predict_tools import BINARY, MULTI, REGRESSION, OTHER, LABEL, PREDICT_SCORE logger = logging.getLogger(__name__) @@ -84,7 +84,7 @@ def get_nn_output_dataframe( self, ctx, predictions: Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput], - labels: Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput], + labels: Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput] = None, match_ids: Union[pd.DataFrame, np.ndarray] = None, sample_ids: Union[pd.DataFrame, np.ndarray] = None, match_id_name: str = None, @@ -115,8 +115,7 @@ def get_nn_output_dataframe( DataFrame: A DataFrame that contains the neural network's predictions and the true labels, possibly along with match IDs and sample IDs, formatted according to the specified format.
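        Example:
            a minimal usage sketch (added for illustration; the keyword names come from the
            signature above, while values such as `pred_arr`, `label_arr`, `mid_arr` and
            `sid_arr` are hypothetical numpy arrays of equal length):

                df = runner.get_nn_output_dataframe(
                    ctx, pred_arr, label_arr,
                    match_ids=mid_arr, sample_ids=sid_arr,
                    match_id_name='id', sample_id_name='sample_id',
                    dataframe_format='fate_std', task_type='binary')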
""" # check parameters - assert task_type in ['binary', 'multi', 'regression', - 'others'], f"task_type {task_type} is not supported" + assert task_type in [BINARY, MULTI, REGRESSION, OTHER], f"task_type {task_type} is not supported" assert dataframe_format in [ 'default', 'fate_std'], f"dataframe_format {dataframe_format} is not supported" @@ -127,13 +126,14 @@ def get_nn_output_dataframe( if isinstance(predictions, PredictionOutput): predictions = predictions.predictions - if isinstance(labels, PredictionOutput): - labels = labels.label_ids - - predictions = _convert_to_numpy_array(predictions) - labels = _convert_to_numpy_array(labels) - assert len(predictions) == len( - labels), f"predictions length {len(predictions)} != labels length {len(labels)}" + + if labels is not None: + if isinstance(labels, PredictionOutput): + labels = labels.label_ids + predictions = _convert_to_numpy_array(predictions) + labels = _convert_to_numpy_array(labels) + assert len(predictions) == len( + labels), f"predictions length {len(predictions)} != labels length {len(labels)}" # check match ids if match_ids is not None: @@ -165,26 +165,17 @@ def get_nn_output_dataframe( sample_id_name, str), f"sample_id_name must be str, but got {type(sample_id_name)}" if dataframe_format == 'default' or ( - dataframe_format == 'fate_std' and task_type == 'others'): - df = pd.DataFrame({'label': labels.to_list(), - 'predict': predictions.to_list(), - match_id_name: match_ids.to_list(), - sample_id_name: sample_ids.to_list()}) + dataframe_format == 'fate_std' and task_type == OTHER): + df = pd.DataFrame() + if labels is not None: + df[LABEL] = labels.to_list() + df[PREDICT_SCORE] = predictions.to_list() + df[match_id_name] = match_ids.flatten() + df[sample_id_name] = sample_ids.flatten() df = to_fate_df(ctx, sample_id_name, match_id_name, df) return df - elif dataframe_format == 'fate_std' and task_type in ['binary', 'multi', 'regression']: - df = std_output_df( - task_type, - predictions, - labels, - threshold, - classes) - match_id_df = pd.DataFrame() - match_id_df[match_id_name] = match_ids - sample_id_df = pd.DataFrame() - sample_id_df[sample_id_name] = sample_ids - df = add_ids(df, match_id_df, sample_id_df) - df = to_fate_df(ctx, sample_id_name, match_id_name, df) + elif dataframe_format == 'fate_std' and task_type in [BINARY, MULTI, REGRESSION]: + df = array_to_predict_df(ctx, task_type, predictions, match_ids, sample_ids, match_id_name, sample_id_name, labels, threshold, classes) return df def train(self, diff --git a/python/fate/components/components/nn/runner/default_runner.py b/python/fate/components/components/nn/runner/default_runner.py index 6df95203b3..56c9c514c9 100644 --- a/python/fate/components/components/nn/runner/default_runner.py +++ b/python/fate/components/components/nn/runner/default_runner.py @@ -237,6 +237,7 @@ def client_setup( if output_dir is None: output_dir = './' + resume_path = None if saved_model is not None: model_dict = load_model_dict_from_path(saved_model) model.load_state_dict(model_dict) @@ -245,9 +246,6 @@ def client_setup( resume_path = saved_model logger.info( f"checkpoint detected, resume_path set to {resume_path}") - else: - resume_path = None - # load optimizer optimizer_loader = Loader.from_dict(self.optimizer_conf) optimizer_ = optimizer_loader.load_item() @@ -369,7 +367,7 @@ def predict(self, rs_df = self.get_nn_output_dataframe( self.get_context(), pred_rs.predictions, - pred_rs.label_ids, + pred_rs.label_ids if hasattr(pred_rs, 'label_ids') else None, match_ids, sample_ids, 
match_id_name=match_id_name, diff --git a/python/fate/components/components/nn/runner/my_runner.py b/python/fate/components/components/nn/runner/my_runner.py deleted file mode 100644 index 63fdc047ae..0000000000 --- a/python/fate/components/components/nn/runner/my_runner.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import Optional, Union -import torch as t -from fate.ml.nn.algo.homo.fedavg import FedAVGArguments, TrainingArguments, FedAVGCLient, FedAVGServer -from fate.components.components.nn.nn_runner import NNInput, NNOutput, NNRunner -from torch.utils.data import TensorDataset - - -class MyRunner(NNRunner): - - def __init__( - self, - in_feat=30, - epoch=10, - learning_rate=0.01, - batch_size=32) -> None: - super().__init__() - self.in_feat = in_feat - self.epoch = epoch - self.learning_rate = learning_rate - self.batch_size = batch_size - - def setup(self, df=None): - - ctx = self.get_context() - - if self.is_client(): - - df = df.drop(columns=['id', 'sample_id']) - X = df.drop(columns=['y']).values - y = df['y'].values - X_tensor = t.tensor(X, dtype=t.float32) - y_tensor = t.tensor(y, dtype=t.float32) - dataset = TensorDataset(X_tensor, y_tensor) - loss_fn = t.nn.BCELoss() - - model = t.nn.Sequential( - t.nn.Linear(self.in_feat, 10), - t.nn.ReLU(), - t.nn.Linear(10, 1), - t.nn.Sigmoid() - ) - - optimizer = t.optim.Adam(model.parameters(), lr=self.learning_rate) - - train_arg = TrainingArguments( - num_train_epochs=self.epoch, - per_device_train_batch_size=self.batch_size, - disable_tqdm=False) - - fed_arg = FedAVGArguments() - - return FedAVGCLient(ctx=ctx, model=model, optimizer=optimizer, training_args=train_arg, - fed_args=fed_arg, train_set=dataset, loss_fn=loss_fn), dataset - - elif self.is_server(): - return FedAVGServer(ctx=ctx) - - def train(self, input_data: NNInput = None): - if self.is_client(): - df = input_data.get('train_data') - trainer, _ = self.setup(df) - elif self.is_server(): - trainer = self.setup() - - trainer.train() - - def predict(self, input_data: NNInput = None): - - if self.is_client(): - df = input_data.get('test_data') - trainer, ds = self.setup(df) - trainer.set_local_mode() - pred_rs = trainer.predict(ds) - print('pred rs is {}'.format(pred_rs)) diff --git a/python/fate/components/components/nn/test/test_default_runner.py b/python/fate/components/components/nn/test/test_default_runner.py index 42c3a14bc3..badd113d77 100644 --- a/python/fate/components/components/nn/test/test_default_runner.py +++ b/python/fate/components/components/nn/test/test_default_runner.py @@ -12,6 +12,7 @@ from fate.arch.dataframe import PandasReader import logging from fate.components.core import GUEST +from fate.ml.utils.predict_tools import predict_detail_dict_to_str # Get the root logger logger = logging.getLogger() @@ -35,10 +36,11 @@ reader = PandasReader( sample_id_name='sample_id', match_id_name="id", - label_name="y", + # label_name="y", dtype="object") data = reader.to_frame(ctx, df) + runner_conf = get_config_of_default_runner( algo='fedavg', model=Sequential( diff --git a/python/fate/ml/glm/homo/lr/client.py b/python/fate/ml/glm/homo/lr/client.py index 8d50f3ad52..98c998a379 100644 --- a/python/fate/ml/glm/homo/lr/client.py +++ b/python/fate/ml/glm/homo/lr/client.py @@ -10,7 +10,7 @@ from transformers import default_data_collator import functools import tempfile -from fate.ml.utils.predict_tools import std_output_df, add_ids, to_fate_df +from fate.ml.utils.predict_tools import array_to_predict_df from fate.ml.utils.predict_tools import MULTI, BINARY from 
fate.ml.nn.dataset.table import TableDataset from fate.ml.utils._optimizer import optimizer_factory, lr_scheduler_factory @@ -276,6 +276,7 @@ def _make_dataset(self, data) -> TableDataset: def _make_output_df( self, + ctx, predict_rs, data: TableDataset, threshold: float): @@ -283,13 +284,20 @@ def _make_output_df( if len(classes) == 1: # binary: classes = [0, 1] task_type = BINARY if len(classes) == 2 else MULTI - out_df = std_output_df( + + out_df = array_to_predict_df( + ctx, task_type, predict_rs.predictions, - predict_rs.label_ids, + match_ids=data.get_match_ids(), + sample_ids=data.get_sample_ids(), + match_id_name=data.get_match_id_name(), + sample_id_name=data.get_sample_id_name(), + label=predict_rs.label_ids, threshold=threshold, - classes=classes) - out_df = add_ids(out_df, data.get_match_ids(), data.get_sample_ids()) + classes=classes + ) + return out_df def _check_labels(self, label_set, has_validate=False): @@ -441,10 +449,8 @@ def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: trainer = self.trainer predict_rs = trainer.predict(self.predict_set) predict_out_df = self._make_output_df( - predict_rs, self.predict_set, self.threshold) - match_id_name = self.predict_set.get_match_ids().columns[0] - sample_id_name = self.predict_set.get_sample_ids().columns[0] - return to_fate_df(ctx, match_id_name, sample_id_name, predict_out_df) + ctx, predict_rs, self.predict_set, self.threshold) + return predict_out_df def get_model(self) -> ModelIO: param = {} diff --git a/python/fate/ml/nn/dataset/table.py b/python/fate/ml/nn/dataset/table.py index 52198c5259..ec6cc6cb9e 100644 --- a/python/fate/ml/nn/dataset/table.py +++ b/python/fate/ml/nn/dataset/table.py @@ -219,11 +219,11 @@ def get_classes(self): raise ValueError( "no label found, please check if self.label is set") - def get_sample_ids(self) -> pd.DataFrame: - return self.sample_ids + def get_sample_ids(self) -> np.ndarray: + return self.sample_ids.values - def get_match_ids(self) -> pd.DataFrame: - return self.match_ids + def get_match_ids(self) -> np.ndarray: + return self.match_ids.values def get_sample_id_name(self) -> str: if self.sample_ids is not None and isinstance( diff --git a/python/fate/ml/utils/predict_tools.py b/python/fate/ml/utils/predict_tools.py index 6c4d48e5b8..da7f8a982d 100644 --- a/python/fate/ml/utils/predict_tools.py +++ b/python/fate/ml/utils/predict_tools.py @@ -4,15 +4,18 @@ import numpy as np from typing import Union from fate.arch.dataframe import DataFrame +from typing import Literal - +# DATA SET COLUMNS TRAIN_SET = 'train_set' VALIDATE_SET = 'validate_set' TEST_SET = 'test_set' -LABEL = "label" + +# PREDICT RESULT COLUMNS PREDICT_RESULT = "predict_result" PREDICT_SCORE = "predict_score" PREDICT_DETAIL = "predict_detail" +LABEL = "label" # TASK TYPE BINARY = 'binary' @@ -30,7 +33,7 @@ def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id: pd.DataFrame): return df -def to_fate_df(ctx, sample_id_name, match_id_name, result_df): +def to_fate_df(ctx, sample_id_name, match_id_name, result_df: pd.DataFrame): if LABEL in result_df: reader = PandasReader( @@ -47,66 +50,60 @@ def to_fate_df(ctx, sample_id_name, match_id_name, result_df): return data -def compute_predict_details( - dataframe: Union[pd.DataFrame, DataFrame], task_type, classes: list = None, threshold=0.5): +def compute_predict_details(dataframe: DataFrame, task_type: Literal['binary', 'multi', 'regression'], classes: list = None, threshold=0.5): assert task_type in [BINARY, MULTI, REGRESSION, OTHER], 'task_type must 
be one of {} as a std task, but got {}'.format( [BINARY, MULTI, REGRESSION, OTHER], task_type) + + assert threshold >= 0 and threshold <= 1, 'threshold must be float in [0, 1], but got {}'.format(threshold) + + if not isinstance(dataframe, DataFrame): + raise ValueError('dataframe must be a fate DataFrame, but got {}'.format(type(dataframe))) + + assert PREDICT_SCORE in dataframe.schema.columns, 'column {} is not found in input dataframe'.format(PREDICT_SCORE) + + if task_type in (BINARY, MULTI): + if classes is None or not isinstance(classes, list) or len(classes) < 2: + raise ValueError('task_type is binary or multi, but classes is None, or classes length is less than 2') + if task_type == BINARY: if len(classes) == 2: - predict_score = np.array(pred) - predict_result = (predict_score > threshold).astype(int) - predict_details = [ - { - classes[0]: 1 - - float( - predict_score[i]), - classes[1]: float( - predict_score[i])} for i in range( - len(predict_score))] + neg_class, pos_class = classes[0], classes[1] + dataframe[[PREDICT_RESULT, PREDICT_DETAIL]] = dataframe.apply_row( \ + lambda v: [int(v[PREDICT_SCORE] > threshold), predict_detail_dict_to_str({neg_class: 1 - v[PREDICT_SCORE], pos_class: v[PREDICT_SCORE]})], + enable_type_align_checking=False) else: raise ValueError( 'task_type is binary, but classes length is not 2: {}'.format(classes)) + + elif task_type == REGRESSION: + dataframe[[PREDICT_RESULT, PREDICT_DETAIL]] = dataframe.apply_row( \ + lambda v: [v[PREDICT_SCORE], predict_detail_dict_to_str({PREDICT_SCORE: v[PREDICT_SCORE]})], enable_type_align_checking=False) elif task_type == MULTI: - if len(classes) > 2: - predict_score = np.array([max(i) for i in pred]) - predict_result = np.array([np.argmax(i) for i in pred]) - predict_details = [predict_detail_dict_to_str({classes[j]: float( - pred[i][j]) for j in range(len(classes))}) for i in range(len(pred))] - else: - raise ValueError( - 'task_type is multi, but classes length is not greater than 2: {}'.format(classes)) - - elif task_type == REGRESSION: - # regression task - predict_score = np.array(pred) + def handle_multi(v: pd.Series): + predict_result = np.argmax(v[PREDICT_SCORE]) + assert len(v[PREDICT_SCORE]) == len(classes), 'predict score length is not equal to classes length,\ + predict score is {}, but classes are {}, please check the data you provided'.format(v[PREDICT_SCORE], classes) + predict_details = {classes[j]: v[PREDICT_SCORE][j] for j in range(len(classes))} + return [predict_result, predict_detail_dict_to_str(predict_details)] - df[PREDICT_RESULT] = predict_result - df[PREDICT_DETAIL] = predict_details - if task_type == MULTI: - df[PREDICT_SCORE] = predict_score + dataframe[[PREDICT_RESULT, PREDICT_DETAIL]] = dataframe.apply_row(handle_multi, enable_type_align_checking=False) + predict_score = dataframe[PREDICT_SCORE].apply_row(lambda v: max(v[PREDICT_SCORE])) + dataframe[PREDICT_SCORE] = predict_score - return df + return dataframe -def std_output_df( - task_type, - pred: np.array, +def array_to_predict_df(
+ ctx, + task_type: Literal['binary', 'multi', 'regression'], + pred: np.ndarray, + match_ids: np.ndarray, + sample_ids: np.ndarray, + match_id_name: str, + sample_id_name: str, label: np.array = None, threshold=0.5, classes: list = None): @@ -114,7 +111,7 @@ def std_output_df( df = pd.DataFrame() if len(pred.shape) == 1: df[PREDICT_SCORE] = np.array(pred) - if len(pred.shape) == 2: + elif len(pred.shape) == 2: if pred.shape[1] == 1: df[PREDICT_SCORE] = np.array(pred).flatten() else: @@ -133,6 +130,9 @@ def std_output_df( label = label.tolist() df[LABEL] = label - df = compute_predict_details(df, task_type, classes, threshold) + df[sample_id_name] = sample_ids.flatten() + df[match_id_name] = match_ids.flatten() + fate_df = to_fate_df(ctx, sample_id_name, match_id_name, df) + predict_result = compute_predict_details(fate_df, task_type, classes, threshold) - return df + return predict_result diff --git a/python/fate/ml/utils/test/test_predict_format.py b/python/fate/ml/utils/test/test_predict_format.py new file mode 100644 index 0000000000..8281388ad2 --- /dev/null +++ b/python/fate/ml/utils/test/test_predict_format.py @@ -0,0 +1,49 @@ +from fate.arch import Context +from fate.arch.computing.standalone import CSession +from fate.arch.context import Context +from fate.arch.federation.standalone import StandaloneFederation +import pandas as pd +from fate.ml.utils.predict_tools import compute_predict_details, PREDICT_SCORE, LABEL, BINARY, REGRESSION, MULTI +from fate.arch.dataframe import PandasReader +import numpy as np + + +computing = CSession() +ctx = Context("guest", computing=computing, federation=StandaloneFederation( + computing, "fed", ("guest", 10000), [("host", 9999)]), ) + +df = pd.DataFrame() +df['id'] = [i for i in range(50)] +df['sample_id'] = [i for i in range(len(df))] +df[PREDICT_SCORE] = [np.random.random(1)[0] for i in range(len(df))] +df[LABEL] = [np.random.randint(0, 2) for i in range(len(df))] + +no_label_df = df.drop([LABEL], axis=1) + + +df_reg = pd.DataFrame() +df_reg['id'] = [i for i in range(50)] +df_reg['sample_id'] = [i for i in range(len(df_reg))] +df_reg[PREDICT_SCORE] = [np.random.random(1)[0] * 10 for i in range(len(df_reg))] +df_reg[LABEL] = [np.random.random(1)[0] * 10 for i in range(len(df_reg))] + +df_multi = pd.DataFrame() +df_multi['id'] = [i for i in range(50)] +df_multi['sample_id'] = [i for i in range(len(df_multi))] +df_multi[PREDICT_SCORE] = [[float(np.random.random(1)[0]) for i in range(3)] for i in range(len(df_multi))] +df_multi[LABEL] = [np.random.randint(0, 3) for i in range(len(df_multi))] + +reader = PandasReader( + sample_id_name='sample_id', + match_id_name="id", + dtype="object") +data = reader.to_frame(ctx, df) +data_2 = reader.to_frame(ctx, no_label_df) +data_3 = reader.to_frame(ctx, df_reg) +data_4 = reader.to_frame(ctx, df_multi) + + +rs = compute_predict_details(data, BINARY, classes=[0, 1], threshold=0.8) +rs_2 = compute_predict_details(data_2, BINARY, classes=[0, 1], threshold=0.3) +rs_3 = compute_predict_details(data_3, REGRESSION) +rs_4 = compute_predict_details(data_4, MULTI, classes=[0, 1, 2]) \ No newline at end of file From 1540dc1d80c936af64e579cecde5702bc98e76db Mon Sep 17 00:00:00 2001 From: weiwee Date: Fri, 14 Jul 2023 17:37:02 +0800 Subject: [PATCH 42/61] add optional ranks for dh aggregator Signed-off-by: weiwee --- python/fate/arch/protocol/_dh.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/fate/arch/protocol/_dh.py b/python/fate/arch/protocol/_dh.py index b834e1c907..51391ea69c 
100644 --- a/python/fate/arch/protocol/_dh.py +++ b/python/fate/arch/protocol/_dh.py @@ -54,11 +54,14 @@ def __init__(self, ranks, prefix: typing.Optional[str] = None): self.prefix = prefix self.ranks = ranks - def secure_aggregate(self, ctx: Context): + def secure_aggregate(self, ctx: Context, ranks: typing.Optional[typing.List[int]] = None): mix_aggregator = MixAggregate() aggregated_weight = 0.0 has_weight = False + + if ranks is None: + ranks = self.ranks - for rank in self.ranks: + for rank in ranks: mix_arrays, weight = ctx.parties[rank].get(self._get_name(self._send_name)) mix_aggregator.aggregate(mix_arrays) if weight is not None: @@ -67,5 +70,5 @@ def secure_aggregate(self, ctx: Context): if not has_weight: aggregated_weight = None aggregated = mix_aggregator.finalize(aggregated_weight) - for rank in self.ranks: + for rank in ranks: ctx.parties[rank].put(self._get_name(self._recv_name), aggregated) From ff7f2c20b13c81176a018ff60d58a5f30e446a9c Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Fri, 14 Jul 2023 17:43:06 +0800 Subject: [PATCH 43/61] coordinated lr & linr support warm start(#4659) Signed-off-by: Yu Wu --- .../components/components/coordinated_linr.py | 57 +++++-- .../components/components/coordinated_lr.py | 141 +++++++++++------- .../ml/glm/hetero/coordinated_linr/arbiter.py | 40 ++--- .../ml/glm/hetero/coordinated_linr/guest.py | 51 ++++--- .../ml/glm/hetero/coordinated_linr/host.py | 51 ++++--- .../ml/glm/hetero/coordinated_lr/arbiter.py | 74 +++++---- .../ml/glm/hetero/coordinated_lr/guest.py | 86 +++++++---- .../fate/ml/glm/hetero/coordinated_lr/host.py | 72 +++++---- python/fate/ml/utils/_model_param.py | 19 +++ python/fate/ml/utils/_optimizer.py | 26 ++-- 10 files changed, 392 insertions(+), 225 deletions(-) diff --git a/python/fate/components/components/coordinated_linr.py b/python/fate/components/components/coordinated_linr.py index 9a83207d95..9741c0ae1d 100644 --- a/python/fate/components/components/coordinated_linr.py +++ b/python/fate/components/components/coordinated_linr.py @@ -59,6 +59,8 @@ def train( desc="Model param init setting."), train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), + warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True), + ): logger.info(f"enter coordinated linr train") # temp code start @@ -69,15 +71,16 @@ def train( if role.is_guest: train_guest( ctx, train_data, validate_data, train_output_data, output_model, epochs, - batch_size, optimizer, learning_rate_scheduler, init_param + batch_size, optimizer, learning_rate_scheduler, init_param, warm_start_model ) elif role.is_host: train_host( ctx, train_data, validate_data, train_output_data, output_model, epochs, - batch_size, optimizer, learning_rate_scheduler, init_param + batch_size, optimizer, learning_rate_scheduler, init_param, warm_start_model ) elif role.is_arbiter: - train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler, output_model) + train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler, output_model, + warm_start_model) @coordinated_linr.predict() @@ -204,12 +207,19 @@ def cross_validation( def train_guest(ctx, train_data, validate_data, train_output_data, output_model, epochs, - batch_size, optimizer_param, learning_rate_param, init_param): + batch_size, optimizer_param, learning_rate_param, init_param, input_model): + if input_model is not None: + logger.info(f"warm start model provided") + model = 
input_model.read() + module = CoordinatedLinRModuleGuest.from_model(model) + module.epochs = epochs + module.batch_size = batch_size + else: + module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size, + optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, + init_param=init_param) logger.info(f"coordinated linr guest start train") sub_ctx = ctx.sub_ctx("train") - module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size, - optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, - init_param=init_param) train_data = train_data.read() if validate_data is not None: validate_data = validate_data.read() @@ -224,6 +234,7 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model, predict_result = transform_to_predict_result(train_data, predict_score, data_type="train") if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") predict_score = module.predict(sub_ctx, validate_data) validate_predict_result = transform_to_predict_result(validate_data, predict_score, data_type="validate") @@ -232,12 +243,20 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model, def train_host(ctx, train_data, validate_data, train_output_data, output_model, epochs, batch_size, - optimizer_param, learning_rate_param, init_param): + optimizer_param, learning_rate_param, init_param, input_model): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLinRModuleHost.from_model(model) + module.epochs = epochs + module.batch_size = batch_size + else: + module = CoordinatedLinRModuleHost(epochs=epochs, batch_size=batch_size, + optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, + init_param=init_param) logger.info(f"coordinated linr host start train") sub_ctx = ctx.sub_ctx("train") - module = CoordinatedLinRModuleHost(epochs=epochs, batch_size=batch_size, - optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, - init_param=init_param) + train_data = train_data.read() if validate_data is not None: validate_data = validate_data.read() @@ -249,17 +268,25 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model, sub_ctx = ctx.sub_ctx("predict") module.predict(sub_ctx, train_data) if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") module.predict(sub_ctx, validate_data) def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, - learning_rate_param, output_model): + learning_rate_param, output_model, input_model): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLinRModuleArbiter.from_model(model) + module.epochs = epochs + module.batch_size = batch_size + else: + module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size, + optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, + ) logger.info(f"coordinated linr arbiter start train") sub_ctx = ctx.sub_ctx("train") - module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size, - optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, - ) module.fit(sub_ctx) model = module.get_model() diff --git a/python/fate/components/components/coordinated_lr.py b/python/fate/components/components/coordinated_lr.py index 02e0b4c03c..cd4bf473e2 100644 --- 
a/python/fate/components/components/coordinated_lr.py +++ b/python/fate/components/components/coordinated_lr.py @@ -60,16 +60,17 @@ def train( default="diff", desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}", ), - init_param: cpn.parameter( - type=params.init_param(), - default=params.InitParam(method="zeros", fit_intercept=True), - desc="Model param init setting.", - ), - threshold: cpn.parameter( - type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" - ), - train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), - output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="zeros", fit_intercept=True), + desc="Model param init setting.", + ), + threshold: cpn.parameter( + type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" + ), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), + warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True), ): logger.info(f"enter coordinated lr train") # temp code start @@ -77,6 +78,7 @@ def train( learning_rate_scheduler = learning_rate_scheduler.dict() init_param = init_param.dict() # temp code end + if role.is_guest: train_guest( ctx, @@ -90,6 +92,7 @@ def train( learning_rate_scheduler, init_param, threshold, + warm_start_model ) elif role.is_host: train_host( @@ -103,9 +106,17 @@ def train( optimizer, learning_rate_scheduler, init_param, + warm_start_model ) elif role.is_arbiter: - train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler, output_model) + train_arbiter(ctx, + epochs, + early_stop, + tol, batch_size, + optimizer, + learning_rate_scheduler, + output_model, + warm_start_model) @coordinated_lr.predict() @@ -242,28 +253,34 @@ def train_guest( ctx, train_data, validate_data, - train_output_data, - output_model, - epochs, - batch_size, - optimizer_param, - learning_rate_param, - init_param, - threshold, + train_output_data, + output_model, + epochs, + batch_size, + optimizer_param, + learning_rate_param, + init_param, + threshold, + input_model ): - from fate.arch.dataframe import DataFrame - + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLRModuleGuest.from_model(model) + module.epochs = epochs + module.batch_size = batch_size + else: + module = CoordinatedLRModuleGuest( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer_param, + learning_rate_param=learning_rate_param, + init_param=init_param, + threshold=threshold, + ) # optimizer = optimizer_factory(optimizer_param) logger.info(f"coordinated lr guest start train") sub_ctx = ctx.sub_ctx("train") - module = CoordinatedLRModuleGuest( - epochs=epochs, - batch_size=batch_size, - optimizer_param=optimizer_param, - learning_rate_param=learning_rate_param, - init_param=init_param, - threshold=threshold, - ) train_data = train_data.read() if validate_data is not None: @@ -281,6 +298,7 @@ def train_guest( train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="train" ) if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") predict_score = module.predict(sub_ctx, validate_data) validate_predict_result = transform_to_predict_result( validate_data, @@ -297,24 +315,32 @@ def train_guest( def 
train_host( ctx, train_data, - validate_data, - train_output_data, - output_model, - epochs, - batch_size, - optimizer_param, - learning_rate_param, - init_param, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + optimizer_param, + learning_rate_param, + init_param, + input_model ): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLRModuleHost.from_model(model) + module.epochs = epochs + module.batch_size = batch_size + else: + module = CoordinatedLRModuleHost( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer_param, + learning_rate_param=learning_rate_param, + init_param=init_param, + ) logger.info(f"coordinated lr host start train") sub_ctx = ctx.sub_ctx("train") - module = CoordinatedLRModuleHost( - epochs=epochs, - batch_size=batch_size, - optimizer_param=optimizer_param, - learning_rate_param=learning_rate_param, - init_param=init_param, - ) train_data = train_data.read() if validate_data is not None: @@ -327,20 +353,29 @@ def train_host( sub_ctx = ctx.sub_ctx("predict") module.predict(sub_ctx, train_data) if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") module.predict(sub_ctx, validate_data) -def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, output_model): +def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, output_model, + input_model): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLRModuleArbiter.from_model(model) + module.epochs = epochs + module.batch_size = batch_size + else: + module = CoordinatedLRModuleArbiter( + epochs=epochs, + early_stop=early_stop, + tol=tol, + batch_size=batch_size, + optimizer_param=optimizer_param, + learning_rate_param=learning_rate_scheduler, + ) logger.info(f"coordinated lr arbiter start train") sub_ctx = ctx.sub_ctx("train") - module = CoordinatedLRModuleArbiter( - epochs=epochs, - early_stop=early_stop, - tol=tol, - batch_size=batch_size, - optimizer_param=optimizer_param, - learning_rate_param=learning_rate_scheduler, - ) module.fit(sub_ctx) model = module.get_model() output_model.write(model, metadata={}) diff --git a/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py b/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py index 9799423c12..a4ae6e40d3 100644 --- a/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py @@ -48,22 +48,23 @@ def __init__( def fit(self, ctx: Context) -> None: encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048)) ctx.hosts("encryptor").put(encryptor) - optimizer = Optimizer( - self.optimizer_param["method"], - self.optimizer_param["penalty"], - self.optimizer_param["alpha"], - self.optimizer_param["optimizer_params"], - ) - lr_scheduler = LRScheduler(self.learning_rate_param["method"], - self.learning_rate_param["scheduler_params"]) - single_estimator = HeteroLinrEstimatorArbiter(epochs=self.epochs, - early_stop=self.early_stop, - tol=self.tol, - batch_size=self.batch_size, - optimizer=optimizer, - learning_rate_scheduler=lr_scheduler) - single_estimator.fit_model(ctx, decryptor) - self.estimator = single_estimator + if self.estimator is None: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + 
self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler(self.learning_rate_param["method"], + self.learning_rate_param["scheduler_params"]) + single_estimator = HeteroLinrEstimatorArbiter(epochs=self.epochs, + early_stop=self.early_stop, + tol=self.tol, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler) + self.estimator = single_estimator + self.estimator.fit_model(ctx, decryptor) def get_model(self): return { @@ -76,6 +77,7 @@ def get_model(self): "optimizer_param": self.optimizer_param}, } + @classmethod def from_model(cls, model): linr = CoordinatedLinRModuleArbiter(model["meta"]["epochs"], model["meta"]["early_stop"], @@ -107,7 +109,8 @@ def __init__( self.optimizer = optimizer self.lr_scheduler = learning_rate_scheduler - self.converge_func = converge_func_factory(early_stop, tol) + if early_stop is not None: + self.converge_func = converge_func_factory(early_stop, tol) self.start_epoch = 0 self.end_epoch = -1 self.is_converged = False @@ -121,7 +124,7 @@ def fit_model(self, ctx, decryptor): optimizer_ready = False else: optimizer_ready = True - self.start_epoch = self.end_epoch + 1 + # self.start_epoch = self.end_epoch + 1 for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): iter_loss = None @@ -204,4 +207,5 @@ def restore(self, model): self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer) self.end_epoch = model["end_epoch"] self.is_converged = model["is_converged"] + self.converge_func = converge_func_factory(self.early_stop, self.tol) # self.start_epoch = model["end_epoch"] + 1 diff --git a/python/fate/ml/glm/hetero/coordinated_linr/guest.py b/python/fate/ml/glm/hetero/coordinated_linr/guest.py index 1a50a45ce6..de46955f3f 100644 --- a/python/fate/ml/glm/hetero/coordinated_linr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/guest.py @@ -19,7 +19,7 @@ from fate.arch import dataframe, Context from fate.ml.abc.module import HeteroModule -from fate.ml.utils._model_param import initialize_param +from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param from fate.ml.utils._optimizer import Optimizer, LRScheduler logger = logging.getLogger(__name__) @@ -43,21 +43,22 @@ def __init__( self.estimator = None def fit(self, ctx: Context, train_data, validate_data=None) -> None: - optimizer = Optimizer( - self.optimizer_param["method"], - self.optimizer_param["penalty"], - self.optimizer_param["alpha"], - self.optimizer_param["optimizer_params"], - ) - lr_scheduler = LRScheduler(self.learning_rate_param["method"], - self.learning_rate_param["scheduler_params"]) - estimator = CoordinatedLinREstimatorGuest(epochs=self.epochs, - batch_size=self.batch_size, - optimizer=optimizer, - learning_rate_scheduler=lr_scheduler, - init_param=self.init_param) - estimator.fit_model(ctx, train_data, validate_data) - self.estimator = estimator + if self.estimator is None: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler(self.learning_rate_param["method"], + self.learning_rate_param["scheduler_params"]) + estimator = CoordinatedLinREstimatorGuest(epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param) + self.estimator = estimator + self.estimator.fit_model(ctx, train_data, validate_data) def predict(self, ctx, 
test_data): prob = self.estimator.predict(ctx, test_data) @@ -76,7 +77,6 @@ def get_model(self): def from_model(cls, model) -> "CoordinatedLinRModuleGuest": linr = CoordinatedLinRModuleGuest(optimizer_param=model["meta"]["optimizer_param"], learning_rate_param=model["meta"]["learning_rate_param"], - epochs=model["meta"]["epochs"], batch_size=model["meta"]["batch_size"], init_param=model["meta"]["init_param"]) estimator = CoordinatedLinREstimatorGuest() @@ -120,8 +120,8 @@ def fit_model(self, ctx, train_data, validate_data=None): batch_loader = dataframe.DataLoader( train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True ) - if self.end_epoch >= 0: - self.start_epoch = self.end_epoch + 1 + # if self.end_epoch >= 0: + # self.start_epoch = self.end_epoch + 1 for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): self.optimizer.set_iters(i) @@ -185,14 +185,15 @@ def predict(self, ctx, test_data): return pred def get_model(self): - w = self.w.tolist() + """w = self.w.tolist() intercept = None if self.init_param.get("fit_intercept"): w = w[:-1] - intercept = w[-1] + intercept = w[-1]""" + param = serialize_param(self.w, self.init_param.get("fit_intercept")) return { - "w": w, - "intercept": intercept, + "param": param, + # "intercept": intercept, "optimizer": self.optimizer.state_dict(), "lr_scheduler": self.lr_scheduler.state_dict(), "end_epoch": self.end_epoch, @@ -201,10 +202,12 @@ def get_model(self): } def restore(self, model): - w = model["w"] + """w = model["w"] if model["fit_intercept"]: w.append(model["intercept"]) self.w = torch.tensor(w) + """ + self.w = deserialize_param(model["param"], model["fit_intercept"]) self.optimizer = Optimizer() self.lr_scheduler = LRScheduler() self.optimizer.load_state_dict(model["optimizer"]) diff --git a/python/fate/ml/glm/hetero/coordinated_linr/host.py b/python/fate/ml/glm/hetero/coordinated_linr/host.py index 5b3807b350..30c4f69480 100644 --- a/python/fate/ml/glm/hetero/coordinated_linr/host.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/host.py @@ -19,7 +19,7 @@ from fate.arch import Context from fate.arch.dataframe import DataLoader from fate.ml.abc.module import HeteroModule -from fate.ml.utils._model_param import initialize_param +from fate.ml.utils._model_param import initialize_param, deserialize_param, serialize_param from fate.ml.utils._optimizer import Optimizer, LRScheduler logger = logging.getLogger(__name__) @@ -45,21 +45,23 @@ def __init__( def fit(self, ctx: Context, train_data, validate_data=None) -> None: encryptor = ctx.arbiter("encryptor").get() - optimizer = Optimizer( - self.optimizer_param["method"], - self.optimizer_param["penalty"], - self.optimizer_param["alpha"], - self.optimizer_param["optimizer_params"], - ) - lr_scheduler = LRScheduler(self.learning_rate_param["method"], - self.learning_rate_param["scheduler_params"]) - estimator = CoordiantedLinREstimatorHost(epochs=self.epochs, - batch_size=self.batch_size, - optimizer=optimizer, - learning_rate_scheduler=lr_scheduler, - init_param=self.init_param) - estimator.fit_model(ctx, encryptor, train_data, validate_data) - self.estimator = estimator + if self.estimator is None: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler(self.learning_rate_param["method"], + self.learning_rate_param["scheduler_params"]) + estimator = 
CoordiantedLinREstimatorHost(epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param) + self.estimator = estimator + + self.estimator.fit_model(ctx, encryptor, train_data, validate_data) def predict(self, ctx, test_data): self.estimator.predict(ctx, test_data) @@ -116,8 +118,8 @@ def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) -> w = initialize_param(coef_count, **self.init_param) self.optimizer.init_optimizer(model_parameter_length=w.size()[0]) self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer) - if self.end_epoch >= 0: - self.start_epoch = self.end_epoch + 1 + # if self.end_epoch >= 0: + # self.start_epoch = self.end_epoch + 1 for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch {i}") @@ -159,8 +161,16 @@ def predict(self, ctx, test_data): ctx.guest.put("h_pred", output) def get_model(self): + """return { + "w": self.w.tolist(), + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged + }""" + param = serialize_param(self.w, False) return { - "w": self.w.tolist(), + "param": param, "optimizer": self.optimizer.state_dict(), "lr_scheduler": self.lr_scheduler.state_dict(), "end_epoch": self.end_epoch, @@ -168,7 +178,8 @@ def get_model(self): } def restore(self, model): - self.w = torch.tensor(model["w"]) + # self.w = torch.tensor(model["w"]) + self.w = deserialize_param(model["param"], False) self.optimizer = Optimizer() self.lr_scheduler = LRScheduler() self.optimizer.load_state_dict(model["optimizer"]) diff --git a/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py index ab9dea4640..1ad6aa5bce 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py @@ -42,10 +42,40 @@ def fit(self, ctx: Context) -> None: encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048)) ctx.hosts("encryptor").put(encryptor) label_count = ctx.guest("label_count").get() - if label_count > 2: + if label_count > 2 or self.ovr: self.ovr = True - self.estimator = {} + warm_start = True + if self.estimator is None: + self.estimator = {} + warm_start = False for i, class_ctx in ctx.sub_ctx("class").ctxs_range(label_count): + if not warm_start: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + single_estimator = CoordinatedLREstimatorArbiter( + epochs=self.epochs, + early_stop=self.early_stop, + tol=self.tol, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator[i] + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + single_estimator.fit_single_model(class_ctx, decryptor) + self.estimator[i] = single_estimator + else: + if self.estimator is None: optimizer = Optimizer( self.optimizer_param["method"], self.optimizer_param["penalty"], @@ -63,26 +93,11 @@ def fit(self, ctx: Context) -> None: optimizer=optimizer, 
learning_rate_scheduler=lr_scheduler, ) - single_estimator.fit_single_model(class_ctx, decryptor) - self.estimator[i] = single_estimator - else: - optimizer = Optimizer( - self.optimizer_param["method"], - self.optimizer_param["penalty"], - self.optimizer_param["alpha"], - self.optimizer_param["optimizer_params"], - ) - lr_scheduler = LRScheduler( - self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] - ) - single_estimator = CoordinatedLREstimatorArbiter( - epochs=self.epochs, - early_stop=self.early_stop, - tol=self.tol, - batch_size=self.batch_size, - optimizer=optimizer, - learning_rate_scheduler=lr_scheduler, - ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size single_estimator.fit_single_model(ctx, decryptor) self.estimator = single_estimator @@ -106,7 +121,8 @@ def get_model(self): }, } - def from_model(cls, model): + @classmethod + def from_model(cls, model) -> "CoordinatedLRModuleArbiter": lr = CoordinatedLRModuleArbiter( epochs=model["meta"]["epochs"], early_stop=model["meta"]["early_stop"], @@ -116,6 +132,7 @@ def from_model(cls, model): learning_rate_param=model["meta"]["learning_rate_param"], ) all_estimator = model["data"]["estimator"] + if lr.ovr: lr.estimator = {label: CoordinatedLREstimatorArbiter().restore(d) for label, d in all_estimator.items()} else: @@ -136,7 +153,8 @@ def __init__( self.optimizer = optimizer self.lr_scheduler = learning_rate_scheduler - self.converge_func = converge_func_factory(early_stop, tol) + if early_stop is not None: + self.converge_func = converge_func_factory(early_stop, tol) self.start_epoch = 0 self.end_epoch = -1 self.is_converged = False @@ -150,14 +168,13 @@ def fit_single_model(self, ctx: Context, decryptor): optimizer_ready = False else: optimizer_ready = True - self.start_epoch = self.end_epoch + 1 + # self.start_epoch = self.end_epoch + 1 for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): iter_loss = None iter_g = None self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch {i}") for batch_ctx, _ in iter_ctx.on_batches.ctxs_zip(batch_loader): - g_guest_enc = batch_ctx.guest.get("g_enc") g_guest = decryptor.decrypt(g_guest_enc) size_list = [g_guest.size()[0]] @@ -223,6 +240,8 @@ def get_model(self): "lr_scheduler": self.lr_scheduler.state_dict(), "end_epoch": self.end_epoch, "is_converged": self.is_converged, + "tol": self.tol, + "early_stop": self.early_stop } def restore(self, model): @@ -232,4 +251,7 @@ def restore(self, model): self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer) self.end_epoch = model["end_epoch"] self.is_converged = model["is_converged"] + self.tol = model["tol"] + self.early_stop = model["early_stop"] + self.converge_func = converge_func_factory(self.early_stop, self.tol) # self.start_epoch = model["end_epoch"] + 1 diff --git a/python/fate/ml/glm/hetero/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py index 63096077e9..ddeda64594 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py @@ -19,7 +19,7 @@ from fate.arch import Context, dataframe from fate.ml.abc.module import HeteroModule -from fate.ml.utils._model_param import initialize_param +from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param from fate.ml.utils._optimizer import 
LRScheduler, Optimizer logger = logging.getLogger(__name__) @@ -52,14 +52,48 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: label_count = train_data_binarized_label.shape[1] ctx.arbiter.put("label_count", label_count) ctx.hosts.put("label_count", label_count) - self.labels = [label_name.split("_")[1] for label_name in train_data_binarized_label.columns] - if label_count > 2: + labels = [label_name.split("_")[1] for label_name in train_data_binarized_label.columns] + if self.labels is None: + self.labels = labels + if label_count > 2 or self.ovr: logger.info(f"OVR data provided, will train OVR models.") self.ovr = True - self.estimator = {} + warm_start = True + if self.estimator is None: + self.estimator = {} + warm_start = False for i, class_ctx in ctx.sub_ctx("class").ctxs_range(label_count): logger.info(f"start train for {i}th class") # optimizer = copy.deepcopy(self.optimizer) + if not warm_start: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler(self.learning_rate_param["method"], + self.learning_rate_param["scheduler_params"]) + single_estimator = CoordinatedLREstimatorGuest( + epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param, + ) + else: + # warm start + logger.info("estimator is not none, will train with warm start") + # single_estimator = self.estimator[self.labels.index(labels[i])] + single_estimator = self.estimator[i] + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]] + single_estimator.fit_single_model(class_ctx, train_data, validate_data) + self.estimator[i] = single_estimator + + else: + if self.estimator is None: optimizer = Optimizer( self.optimizer_param["method"], self.optimizer_param["penalty"], @@ -75,25 +109,11 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: learning_rate_scheduler=lr_scheduler, init_param=self.init_param, ) - train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]] - single_estimator.fit_single_model(class_ctx, train_data, validate_data) - self.estimator[i] = single_estimator - else: - optimizer = Optimizer( - self.optimizer_param["method"], - self.optimizer_param["penalty"], - self.optimizer_param["alpha"], - self.optimizer_param["optimizer_params"], - ) - lr_scheduler = LRScheduler(self.learning_rate_param["method"], - self.learning_rate_param["scheduler_params"]) - single_estimator = CoordinatedLREstimatorGuest( - epochs=self.epochs, - batch_size=self.batch_size, - optimizer=optimizer, - learning_rate_scheduler=lr_scheduler, - init_param=self.init_param, - ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size single_estimator.fit_single_model(ctx, train_data, validate_data) self.estimator = single_estimator train_data.label = original_label @@ -191,8 +211,8 @@ def fit_single_model(self, ctx: Context, train_data, validate_data=None): batch_loader = dataframe.DataLoader( train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True ) - if self.end_epoch >= 0: - self.start_epoch = self.end_epoch + 1 + # if self.end_epoch >= 0: + # 
self.start_epoch = self.end_epoch + 1 for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): self.optimizer.set_iters(i) @@ -256,6 +276,7 @@ def predict(self, ctx, test_data): if self.init_param.get("fit_intercept"): test_data["intercept"] = 1.0 X = test_data.values.as_tensor() + logger.info(f"in predict, w: {self.w}") pred = torch.matmul(X, self.w) for h_pred in ctx.hosts.get("h_pred"): pred += h_pred @@ -263,14 +284,16 @@ def predict(self, ctx, test_data): return pred def get_model(self): - w = self.w.tolist() + """w = self.w.tolist() intercept = None if self.init_param.get("fit_intercept"): w = w[:-1] - intercept = w[-1] + intercept = w[-1]""" + param = serialize_param(self.w, self.init_param.get("fit_intercept")) return { - "w": w, - "intercept": intercept, + # "w": w, + # "intercept": intercept, + "param": param, "optimizer": self.optimizer.state_dict(), "lr_scheduler": self.lr_scheduler.state_dict(), "end_epoch": self.end_epoch, @@ -279,10 +302,11 @@ def get_model(self): } def restore(self, model): - w = model["w"] + """w = model["w"] if model["fit_intercept"]: w.append(model["intercept"]) - self.w = torch.tensor(w) + self.w = torch.tensor(w)""" + self.w = deserialize_param(model["param"], model["fit_intercept"]) self.optimizer = Optimizer() self.lr_scheduler = LRScheduler() self.optimizer.load_state_dict(model["optimizer"]) diff --git a/python/fate/ml/glm/hetero/coordinated_lr/host.py b/python/fate/ml/glm/hetero/coordinated_lr/host.py index 5ff4a8e024..cad6e7cec5 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/host.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/host.py @@ -19,7 +19,7 @@ from fate.arch import Context from fate.arch.dataframe import DataLoader from fate.ml.abc.module import HeteroModule -from fate.ml.utils._model_param import initialize_param +from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param from fate.ml.utils._optimizer import LRScheduler, Optimizer logger = logging.getLogger(__name__) @@ -44,13 +44,42 @@ def __init__( def fit(self, ctx: Context, train_data, validate_data=None) -> None: encryptor = ctx.arbiter("encryptor").get() - self.label_count = ctx.guest("label_count").get() - if self.label_count > 2: + label_count = ctx.guest("label_count").get() + if self.label_count > 2 or self.ovr: self.ovr = True - self.estimator = {} - for i, class_ctx in ctx.sub_ctx("class").ctxs_range(self.label_count): + self.label_count = label_count + warm_start = True + if self.estimator is None: + self.estimator = {} + warm_start = False + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(label_count): # optimizer = copy.deepcopy(self.optimizer) # lr_scheduler = copy.deepcopy(self.lr_scheduler) + if not warm_start: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler(self.learning_rate_param["method"], + self.learning_rate_param["scheduler_params"]) + single_estimator = CoordinatedLREstimatorHost( + epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param, + ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator[i] + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + single_estimator.fit_single_model(class_ctx, encryptor, train_data, validate_data) + self.estimator[i] = 
single_estimator + else: + if self.estimator is None: optimizer = Optimizer( self.optimizer_param["method"], self.optimizer_param["penalty"], @@ -66,24 +95,11 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: learning_rate_scheduler=lr_scheduler, init_param=self.init_param, ) - single_estimator.fit_single_model(class_ctx, encryptor, train_data, validate_data) - self.estimator[i] = single_estimator - else: - optimizer = Optimizer( - self.optimizer_param["method"], - self.optimizer_param["penalty"], - self.optimizer_param["alpha"], - self.optimizer_param["optimizer_params"], - ) - lr_scheduler = LRScheduler(self.learning_rate_param["method"], - self.learning_rate_param["scheduler_params"]) - single_estimator = CoordinatedLREstimatorHost( - epochs=self.epochs, - batch_size=self.batch_size, - optimizer=optimizer, - learning_rate_scheduler=lr_scheduler, - init_param=self.init_param, - ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size single_estimator.fit_single_model(ctx, encryptor, train_data, validate_data) self.estimator = single_estimator @@ -155,8 +171,8 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=No self.optimizer.init_optimizer(model_parameter_length=w.size()[0]) self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer) batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host") - if self.end_epoch >= 0: - self.start_epoch = self.end_epoch + 1 + # if self.end_epoch >= 0: + # self.start_epoch = self.end_epoch + 1 for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch{i}") @@ -197,8 +213,9 @@ def predict(self, ctx, test_data): ctx.guest.put("h_pred", output) def get_model(self): + param = serialize_param(self.w, False) return { - "w": self.w.tolist(), + "param": param, "optimizer": self.optimizer.state_dict(), "lr_scheduler": self.lr_scheduler.state_dict(), "end_epoch": self.end_epoch, @@ -206,7 +223,8 @@ def get_model(self): } def restore(self, model): - self.w = torch.tensor(model["w"]) + # self.w = torch.tensor(model["w"]) + self.w = deserialize_param(model["param"], False) self.optimizer = Optimizer() self.lr_scheduler = LRScheduler() self.optimizer.load_state_dict(model["optimizer"]) diff --git a/python/fate/ml/utils/_model_param.py b/python/fate/ml/utils/_model_param.py index e57f9cc29a..e9b8821b01 100644 --- a/python/fate/ml/utils/_model_param.py +++ b/python/fate/ml/utils/_model_param.py @@ -32,3 +32,22 @@ def initialize_param(coef_len, **kwargs): return torch.randn((param_len, 1), requires_grad=True) else: raise NotImplementedError(f"Unknown initialization method: {method}") + + +def serialize_param(param, fit_intercept=False): + dtype = str(param.dtype).split(".", -1)[-1] + w = param.tolist() + intercept = None + if fit_intercept: + w = w[:-1] + intercept = w[-1] + return {"coef_": w, "intercept_": intercept, "dtype": dtype} + + +def deserialize_param(param, fit_intercept=False): + w = param["coef_"] + if fit_intercept: + w.append(param["intercept_"]) + dtype = param["dtype"] + w = torch.tensor(w, dtype=getattr(torch, dtype)) + return w diff --git a/python/fate/ml/utils/_optimizer.py b/python/fate/ml/utils/_optimizer.py index 64a70508dc..0ac158d64d 100644 --- a/python/fate/ml/utils/_optimizer.py +++ 
b/python/fate/ml/utils/_optimizer.py @@ -108,6 +108,7 @@ def state_dict(self): for k, v in state_all.items(): if isinstance(v, torch.Tensor): state_all[k] = v.tolist() + dtype = str(self.model_parameter.dtype).split(".", -1)[-1] return { "l2_penalty": self.l2_penalty, "l1_penalty": self.l1_penalty, @@ -115,22 +116,25 @@ def state_dict(self): "optimizer": optimizer_state_dict, "method": self.method, "optim_param": self.optim_param, - "model_parameter": self.model_parameter.tolist() + "model_parameter": self.model_parameter.tolist(), + "model_parameter_dtype": dtype } - def load_state_dict(self, dict): - self.l2_penalty = dict["l2_penalty"] - self.l1_penalty = dict["l1_penalty"] - self.alpha = dict["alpha"] - self.method = dict["method"] - self.optim_param = dict["optim_param"] - self.init_optimizer(model_parameter=torch.nn.parameter.Parameter(torch.tensor(dict["model_parameter"]))) - state_dict = dict["optimizer"] - state_all = state_dict['state'].get(0, {}) + def load_state_dict(self, state_dict): + self.l2_penalty = state_dict["l2_penalty"] + self.l1_penalty = state_dict["l1_penalty"] + self.alpha = state_dict["alpha"] + self.method = state_dict["method"] + self.optim_param = state_dict["optim_param"] + dtype = state_dict["model_parameter_dtype"] + self.init_optimizer(model_parameter=torch.nn.parameter.Parameter(torch.tensor(state_dict["model_parameter"], + dtype=getattr(torch, dtype)))) + state = state_dict["optimizer"] + state_all = state['state'].get(0, {}) for k, v in state_all.items(): if isinstance(v, list): state_all[k] = torch.tensor(v) - self.optimizer.load_state_dict(dict["optimizer"]) + self.optimizer.load_state_dict(state_dict["optimizer"]) def set_iters(self, new_iters): self.iters = new_iters From f4647ded8da21fd4794cea3ab96136e2bbef15b1 Mon Sep 17 00:00:00 2001 From: cwj Date: Fri, 14 Jul 2023 17:45:44 +0800 Subject: [PATCH 44/61] add test files Signed-off-by: cwj --- python/fate/arch/protocol/_dh.py | 1 - .../ml/aggregator/test/test_fate_utils.py | 52 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 python/fate/ml/aggregator/test/test_fate_utils.py diff --git a/python/fate/arch/protocol/_dh.py b/python/fate/arch/protocol/_dh.py index 51391ea69c..b2769fcba4 100644 --- a/python/fate/arch/protocol/_dh.py +++ b/python/fate/arch/protocol/_dh.py @@ -44,7 +44,6 @@ def dh_exchange(self, ctx: Context, ranks: typing.List[int]): def secure_aggregate(self, ctx: Context, array: typing.List[numpy.ndarray], weight: typing.Optional[int] = None): mixed = self._get_mixer().mix(array, weight) - print(mixed) ctx.arbiter.put(self._get_name(self._send_name), (mixed, weight)) return ctx.arbiter.get(self._get_name(self._recv_name)) diff --git a/python/fate/ml/aggregator/test/test_fate_utils.py b/python/fate/ml/aggregator/test/test_fate_utils.py new file mode 100644 index 0000000000..4a3c2fab6b --- /dev/null +++ b/python/fate/ml/aggregator/test/test_fate_utils.py @@ -0,0 +1,52 @@ +import sys + +arbiter = ("arbiter", 10000) +guest = ("guest", 10000) +host = ("host", 9999) +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + 
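# A minimal sketch of the dtype round trip that the state_dict() /
# load_state_dict() changes above rely on: the tensor dtype is stored as a
# plain string and restored with getattr(torch, ...). Standalone, assuming
# only torch; the key names mirror the patched state dict.
import torch

w = torch.zeros((3, 1), dtype=torch.float64)
dtype_name = str(w.dtype).split(".", -1)[-1]  # "float64"
state = {"model_parameter": w.tolist(), "model_parameter_dtype": dtype_name}

restored = torch.tensor(
    state["model_parameter"], dtype=getattr(torch, state["model_parameter_dtype"])
)
assert restored.dtype == w.dtype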
+ logger.addHandler(console_handler) + computing = CSession() + return Context(computing=computing, + federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])) + + +if __name__ == "__main__": + if sys.argv[1] == "guest": + from fate.arch.protocol import SecureAggregatorClient + import numpy as np + + ctx = create_ctx(guest) + client = SecureAggregatorClient() + client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + print('ranks are {}'.format([ctx.guest.rank, *ctx.hosts.ranks])) + print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))])) + elif sys.argv[1] == "host": + from fate.arch.protocol import SecureAggregatorClient + import numpy as np + + ctx = create_ctx(host) + client = SecureAggregatorClient() + client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))])) + else: + from fate.arch.protocol import SecureAggregatorServer + + ctx = create_ctx(arbiter) + server = SecureAggregatorServer([ctx.guest.rank, *ctx.hosts.ranks]) + server.secure_aggregate(ctx) \ No newline at end of file From 58b505fa8aef954c3f1876207bfc73e98eeba83c Mon Sep 17 00:00:00 2001 From: weiwee Date: Fri, 14 Jul 2023 18:06:50 +0800 Subject: [PATCH 45/61] add mock Signed-off-by: weiwee --- python/fate/arch/protocol/_dh.py | 77 +++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 17 deletions(-) diff --git a/python/fate/arch/protocol/_dh.py b/python/fate/arch/protocol/_dh.py index b2769fcba4..b907e645ec 100644 --- a/python/fate/arch/protocol/_dh.py +++ b/python/fate/arch/protocol/_dh.py @@ -17,9 +17,16 @@ def _get_name(self, name): class SecureAggregatorClient(_SecureAggregatorMeta): - def __init__(self, prefix: typing.Optional[str] = None): + def __init__(self, prefix: typing.Optional[str] = None, is_mock: bool = False): + """ + secure aggregation client + Args: + prefix: unique prefix for this aggregator + is_mock: mock the aggregator, do not perform secure aggregation, for test only + """ self.prefix = prefix self._mixer = None + self._is_mock = is_mock def _get_mixer(self): if self._mixer is None: @@ -27,6 +34,8 @@ def _get_mixer(self): return self._mixer def dh_exchange(self, ctx: Context, ranks: typing.List[int]): + if self._is_mock: + return local_rank = ctx.local.rank dh = {} seeds = {} @@ -43,31 +52,65 @@ def dh_exchange(self, ctx: Context, ranks: typing.List[int]): self._mixer = RandomMix(seeds, local_rank) def secure_aggregate(self, ctx: Context, array: typing.List[numpy.ndarray], weight: typing.Optional[int] = None): - mixed = self._get_mixer().mix(array, weight) - ctx.arbiter.put(self._get_name(self._send_name), (mixed, weight)) - return ctx.arbiter.get(self._get_name(self._recv_name)) + if self._is_mock: + ctx.arbiter.put(self._get_name(self._send_name), (array, weight)) + return ctx.arbiter.get(self._get_name(self._recv_name)) + else: + mixed = self._get_mixer().mix(array, weight) + ctx.arbiter.put(self._get_name(self._send_name), (mixed, weight)) + return ctx.arbiter.get(self._get_name(self._recv_name)) class SecureAggregatorServer(_SecureAggregatorMeta): - def __init__(self, ranks, prefix: typing.Optional[str] = None): + def __init__(self, ranks, prefix: typing.Optional[str] = None, is_mock: bool = False): + """ + secure aggregation server + Args: + ranks: all ranks + prefix: unique prefix for this aggregator + is_mock: mock the aggregator, do not perform secure aggregation, for test only + """ self.prefix = prefix self.ranks = ranks + self._is_mock = is_mock def 
secure_aggregate(self, ctx: Context, ranks: typing.Optional[int] = None): - mix_aggregator = MixAggregate() + """ + perform secure aggregate once + Args: + ctx: Context to use + ranks: ranks to aggregate, if None, use all ranks + """ + if ranks is None: + ranks = self.ranks aggregated_weight = 0.0 has_weight = False - if ranks is None: - ranks = self.ranks - for rank in ranks: - mix_arrays, weight = ctx.parties[rank].get(self._get_name(self._send_name)) - mix_aggregator.aggregate(mix_arrays) - if weight is not None: - has_weight = True - aggregated_weight += weight - if not has_weight: - aggregated_weight = None - aggregated = mix_aggregator.finalize(aggregated_weight) + if self._is_mock: + aggregated = [] + for rank in ranks: + arrays, weight = ctx.parties[rank].get(self._get_name(self._send_name)) + for i in range(len(arrays)): + if len(aggregated) <= i: + aggregated.append(arrays[i]) + else: + aggregated[i] += arrays[i] + if weight is not None: + has_weight = True + aggregated_weight += weight + if has_weight: + aggregated = [x / aggregated_weight for x in aggregated] + else: + mix_aggregator = MixAggregate() + for rank in ranks: + mix_arrays, weight = ctx.parties[rank].get(self._get_name(self._send_name)) + mix_aggregator.aggregate(mix_arrays) + if weight is not None: + has_weight = True + aggregated_weight += weight + if not has_weight: + aggregated_weight = None + aggregated = mix_aggregator.finalize(aggregated_weight) + for rank in ranks: ctx.parties[rank].put(self._get_name(self._recv_name), aggregated) From 3e12a97899dcab047eb027b60b07df9ae4ca980d Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Fri, 14 Jul 2023 18:08:19 +0800 Subject: [PATCH 46/61] dataframe: support hist interface Signed-off-by: mgqa34 --- python/fate/arch/dataframe/_dataframe.py | 5 ++ python/fate/arch/dataframe/ops/_encoder.py | 10 +++ python/fate/arch/dataframe/ops/_histogram.py | 67 ++++++++++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 python/fate/arch/dataframe/ops/_histogram.py diff --git a/python/fate/arch/dataframe/_dataframe.py b/python/fate/arch/dataframe/_dataframe.py index 5b27cada4b..e41927ef98 100644 --- a/python/fate/arch/dataframe/_dataframe.py +++ b/python/fate/arch/dataframe/_dataframe.py @@ -285,6 +285,11 @@ def bucketize(self, boundaries: Union[dict, pd.DataFrame]) -> "DataFrame": from .ops._encoder import bucketize return bucketize(self, boundaries) + def hist(self, targets): + from .ops._histogram import hist + + return hist(self, targets) + def __add__(self, other: Union[int, float, list, "np.ndarray", "DataFrame", "pd.Series"]) -> "DataFrame": return self.__arithmetic_operate(operator.add, other) diff --git a/python/fate/arch/dataframe/ops/_encoder.py b/python/fate/arch/dataframe/ops/_encoder.py index ebb0a70863..44693e34ae 100644 --- a/python/fate/arch/dataframe/ops/_encoder.py +++ b/python/fate/arch/dataframe/ops/_encoder.py @@ -20,6 +20,7 @@ import torch from sklearn.preprocessing import OneHotEncoder from typing import Union +from ._compress_block import compress_blocks from .._dataframe import DataFrame from ..manager import BlockType, DataManager @@ -166,6 +167,15 @@ def _mapper(blocks, boundaries_list: list = None, narrow_loc: list = None, block_table = df.block_table.mapValues(bucketize_mapper) + block_indexes = data_manager.infer_operable_blocks() + if len(block_indexes) > 1: + to_promote_types = [] + for bid in block_indexes: + to_promote_types.append((bid, BlockType.get_block_type(BUCKETIZE_RESULT_TYPE))) + + data_manager.promote_types(to_promote_types) + 
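# A toy illustration of the additive-masking idea behind the RandomMix /
# MixAggregate path above: pairwise masks cancel when the server sums the
# masked arrays, so only the aggregate is revealed. Toy numbers only; this
# is not the patch's DH-seeded pairwise PRG.
import numpy as np

rng = np.random.default_rng(0)
a, b = np.array([1.0, 2.0]), np.array([3.0, 4.0])
mask = rng.normal(size=2)            # mask shared by the two clients
masked_a, masked_b = a + mask, b - mask
assert np.allclose(masked_a + masked_b, a + b)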
block_table, data_manager = compress_blocks(block_table, data_manager) + return DataFrame( df._ctx, block_table, diff --git a/python/fate/arch/dataframe/ops/_histogram.py b/python/fate/arch/dataframe/ops/_histogram.py new file mode 100644 index 0000000000..1c26eafa5e --- /dev/null +++ b/python/fate/arch/dataframe/ops/_histogram.py @@ -0,0 +1,67 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import functools + +from fate.arch.tensor.inside import Hist +from ._compress_block import compress_blocks +from .._dataframe import DataFrame +from ..manager import BlockType, DataManager + + +def hist(df: DataFrame, targets): + data_manager = df.data_manager + + block_table, data_manager = _try_to_compress_table(df.block_table, data_manager) + block_id = data_manager.infer_operable_blocks()[0] + + def _mapper(blocks, target, bid: int = None): + histogram = Hist() + histogram.update(blocks[bid], target) + + return histogram + + def _reducer(l_histogram, r_histogram): + return l_histogram.merge(r_histogram) + + _mapper_func = functools.partial(_mapper, bid=block_id) + + return block_table.join(targets.shardings._data, _mapper_func).reduce(_reducer) + + +def _try_to_compress_table(block_table, data_manager: DataManager): + block_indexes = data_manager.infer_operable_blocks() + if len(block_indexes) == 1: + return block_table, data_manager + + block_type = None + for block_id in block_indexes: + _type = data_manager.get_block(block_id).block_type + if not BlockType.is_integer(_type): + raise ValueError("To use hist interface, indexes type should be integer >= 0") + + if not block_type: + block_type = _type + elif block_type < _type: + block_type = _type + + to_promote_types = [] + for bid in block_indexes: + to_promote_types.append((bid, block_type)) + + data_manager.promote_types(to_promote_types) + block_table, data_manager = compress_blocks(block_table, data_manager) + + return block_table, data_manager From 6c32afcbf472f623620aed0883072a78c48d98db Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Fri, 14 Jul 2023 18:10:55 +0800 Subject: [PATCH 47/61] fix lr & linr load model & parameter update when warm start(#4659) Signed-off-by: Yu Wu --- examples/pipeline/test_linr_sid_warm_start.py | 81 +++++++++++++++++++ examples/pipeline/test_lr_sid_warm_start.py | 81 +++++++++++++++++++ .../components/components/coordinated_linr.py | 12 +-- .../components/components/coordinated_lr.py | 13 +-- .../ml/glm/hetero/coordinated_linr/arbiter.py | 16 +++- .../ml/glm/hetero/coordinated_linr/guest.py | 14 +++- .../ml/glm/hetero/coordinated_linr/host.py | 14 +++- .../ml/glm/hetero/coordinated_lr/arbiter.py | 18 ++++- .../ml/glm/hetero/coordinated_lr/guest.py | 18 ++++- .../fate/ml/glm/hetero/coordinated_lr/host.py | 18 ++++- 10 files changed, 264 insertions(+), 21 deletions(-) create mode 100644 examples/pipeline/test_linr_sid_warm_start.py create mode 100644 examples/pipeline/test_lr_sid_warm_start.py diff --git a/examples/pipeline/test_linr_sid_warm_start.py 
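# A minimal sketch of the map/reduce shape that hist() above follows: every
# partition builds a local histogram, and partial histograms are merged
# pairwise. ToyHist is a stand-in for fate.arch.tensor.inside.Hist, which is
# not reproduced here.
import functools

class ToyHist:
    def __init__(self):
        self.counts = {}

    def update(self, values, targets):
        # count (bin_index, target) pairs within one partition
        for v, t in zip(values, targets):
            self.counts[(v, t)] = self.counts.get((v, t), 0) + 1
        return self

    def merge(self, other):
        for k, n in other.counts.items():
            self.counts[k] = self.counts.get(k, 0) + n
        return self

partitions = [([0, 1, 1], [0, 0, 1]), ([1, 0], [1, 1])]
local = [ToyHist().update(v, t) for v, t in partitions]
total = functools.reduce(lambda l, r: l.merge(r), local)
assert total.counts[(1, 1)] == 2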
b/examples/pipeline/test_linr_sid_warm_start.py new file mode 100644 index 0000000000..14837e09a9 --- /dev/null +++ b/examples/pipeline/test_linr_sid_warm_start.py @@ -0,0 +1,81 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLinR, Intersection +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel + +pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") + +intersect_0 = Intersection("intersect_0", method="raw") +intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid")) +intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace="experiment_sid")) +linr_0 = CoordinatedLinR("linr_0", + epochs=3, + batch_size=None, + optimizer={"method": "sgd", "optimizer_params": {"lr": 0.01}}, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=intersect_0.outputs["output_data"]) +linr_1 = CoordinatedLinR("linr_1", train_data=intersect_0.outputs["output_data"], + warm_start_model=linr_0.outputs["output_model"], + epochs=2, + batch_size=200) + +"""linr_0.guest.component_setting(train_data=DataWarehouseChannel(name="breast_hetero_guest_sid", + namespace="experiment")) +linr_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name="breast_hetero_host_sid", + namespace="experiment"))""" + +evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + input_data=linr_0.outputs["train_output_data"]) + +# pipeline.add_task(feature_scale_0) +# pipeline.add_task(feature_scale_1) +pipeline.add_task(intersect_0) +pipeline.add_task(linr_0) +pipeline.add_task(linr_1) +# pipeline.add_task(evaluation_0) +# pipeline.add_task(hetero_feature_binning_0) +pipeline.compile() +print(pipeline.get_dag()) +pipeline.fit() +print(f"linr_0 model: {pipeline.get_task_info('linr_0').get_output_model()}") +# print(f"linr_0 data: {pipeline.get_task_info('linr_0').get_output_data()}") +print(f"\nlinr_1 model: {pipeline.get_task_info('linr_1').get_output_model()}") + +"""# print(pipeline.get_task_info("statistics_0").get_output_model()) +print(pipeline.get_task_info("linr_0").get_output_model()) +print(pipeline.get_task_info("linr_0").get_output_metrics()) +print(f"evaluation metrics: ") +print(pipeline.get_task_info("evaluation_0").get_output_metrics()) + +pipeline.deploy([intersect_0, linr_0]) + +predict_pipeline = FateFlowPipeline() + +deployed_pipeline = pipeline.get_deployed_pipeline() +deployed_pipeline.intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid")) +deployed_pipeline.intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace="experiment_sid")) + 
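# A condensed sketch of the component-side dispatch that the warm_start_model
# input wired above exercises. Names follow the patch; module_cls and the
# model reader are stand-ins, not real FATE classes.
def build_module(input_model, module_cls, epochs, batch_size, **init_kwargs):
    if input_model is not None:
        # resume: rebuild the module from the saved model, then override only
        # the run-length parameters for the new training round
        module = module_cls.from_model(input_model.read())
        module.set_epochs(epochs)
        module.set_batch_size(batch_size)
    else:
        # cold start: construct a fresh module with the full init parameters
        module = module_cls(epochs=epochs, batch_size=batch_size, **init_kwargs)
    return module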
+predict_pipeline.add_task(deployed_pipeline) +predict_pipeline.compile() +# print("\n\n\n") +# print(predict_pipeline.compile().get_dag()) +predict_pipeline.predict()""" diff --git a/examples/pipeline/test_lr_sid_warm_start.py b/examples/pipeline/test_lr_sid_warm_start.py new file mode 100644 index 0000000000..bbd548313d --- /dev/null +++ b/examples/pipeline/test_lr_sid_warm_start.py @@ -0,0 +1,81 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, Intersection +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel + +pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") + +intersect_0 = Intersection("intersect_0", method="raw") +intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid")) +intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace="experiment_sid")) +lr_0 = CoordinatedLR("lr_0", + epochs=3, + batch_size=None, + optimizer={"method": "sgd", "optimizer_params": {"lr": 0.01}}, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=intersect_0.outputs["output_data"]) +lr_1 = CoordinatedLR("lr_1", train_data=intersect_0.outputs["output_data"], + warm_start_model=lr_0.outputs["output_model"], + epochs=2, + batch_size=200) + +"""lr_0.guest.component_setting(train_data=DataWarehouseChannel(name="breast_hetero_guest_sid", + namespace="experiment")) +lr_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name="breast_hetero_host_sid", + namespace="experiment"))""" + +evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + input_data=lr_0.outputs["train_output_data"]) + +# pipeline.add_task(feature_scale_0) +# pipeline.add_task(feature_scale_1) +pipeline.add_task(intersect_0) +pipeline.add_task(lr_0) +pipeline.add_task(lr_1) +# pipeline.add_task(evaluation_0) +# pipeline.add_task(hetero_feature_binning_0) +pipeline.compile() +print(pipeline.get_dag()) +pipeline.fit() +print(f"lr_0 model: {pipeline.get_task_info('lr_0').get_output_model()}") +# print(f"lr_0 data: {pipeline.get_task_info('lr_0').get_output_data()}") +print(f"\nlr_1 model: {pipeline.get_task_info('lr_1').get_output_model()}") + +"""# print(pipeline.get_task_info("statistics_0").get_output_model()) +print(pipeline.get_task_info("lr_0").get_output_model()) +print(pipeline.get_task_info("lr_0").get_output_metrics()) +print(f"evaluation metrics: ") +print(pipeline.get_task_info("evaluation_0").get_output_metrics()) + +pipeline.deploy([intersect_0, lr_0]) + +predict_pipeline = FateFlowPipeline() + +deployed_pipeline = pipeline.get_deployed_pipeline() +deployed_pipeline.intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid")) 
+deployed_pipeline.intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace="experiment_sid")) + +predict_pipeline.add_task(deployed_pipeline) +predict_pipeline.compile() +# print("\n\n\n") +# print(predict_pipeline.compile().get_dag()) +predict_pipeline.predict()""" diff --git a/python/fate/components/components/coordinated_linr.py b/python/fate/components/components/coordinated_linr.py index 9741c0ae1d..118520a5c0 100644 --- a/python/fate/components/components/coordinated_linr.py +++ b/python/fate/components/components/coordinated_linr.py @@ -212,8 +212,8 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model, logger.info(f"warm start model provided") model = input_model.read() module = CoordinatedLinRModuleGuest.from_model(model) - module.epochs = epochs - module.batch_size = batch_size + module.set_epochs(epochs) + module.set_batch_size(batch_size) else: module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, @@ -248,8 +248,8 @@ def train_host(ctx, train_data, validate_data, train_output_data, output_model, logger.info(f"warm start model provided") model = input_model.read() module = CoordinatedLinRModuleHost.from_model(model) - module.epochs = epochs - module.batch_size = batch_size + module.set_epochs(epochs) + module.set_batch_size(batch_size) else: module = CoordinatedLinRModuleHost(epochs=epochs, batch_size=batch_size, optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, @@ -278,8 +278,8 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, logger.info(f"warm start model provided") model = input_model.read() module = CoordinatedLinRModuleArbiter.from_model(model) - module.epochs = epochs - module.batch_size = batch_size + module.set_epochs(epochs) + module.set_batch_size(batch_size) else: module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size, optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, diff --git a/python/fate/components/components/coordinated_lr.py b/python/fate/components/components/coordinated_lr.py index cd4bf473e2..03de669234 100644 --- a/python/fate/components/components/coordinated_lr.py +++ b/python/fate/components/components/coordinated_lr.py @@ -267,8 +267,9 @@ def train_guest( logger.info(f"warm start model provided") model = input_model.read() module = CoordinatedLRModuleGuest.from_model(model) - module.epochs = epochs - module.batch_size = batch_size + module.set_epochs(epochs) + module.set_batch_size(batch_size) + else: module = CoordinatedLRModuleGuest( epochs=epochs, @@ -329,8 +330,8 @@ def train_host( logger.info(f"warm start model provided") model = input_model.read() module = CoordinatedLRModuleHost.from_model(model) - module.epochs = epochs - module.batch_size = batch_size + module.set_epochs(epochs) + module.set_batch_size(batch_size) else: module = CoordinatedLRModuleHost( epochs=epochs, @@ -363,8 +364,8 @@ def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, lea logger.info(f"warm start model provided") model = input_model.read() module = CoordinatedLRModuleArbiter.from_model(model) - module.epochs = epochs - module.batch_size = batch_size + module.set_epochs(epochs) + module.set_batch_size(batch_size) else: module = CoordinatedLRModuleArbiter( epochs=epochs, diff --git a/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py 
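# A simplified stand-in showing why the set_epochs/set_batch_size calls above
# replace bare attribute assignment: on warm start the new run-length
# parameters must also reach the restored estimator(s), including the
# per-class dict kept in the OVR case. Not the real FATE classes.
class _Estimator:
    def __init__(self):
        self.epochs = None

class _Module:
    def __init__(self, ovr):
        self.ovr = ovr
        self.estimator = {0: _Estimator(), 1: _Estimator()} if ovr else _Estimator()

    def set_epochs(self, epochs):
        self.epochs = epochs
        if self.ovr:
            for est in self.estimator.values():
                est.epochs = epochs
        else:
            self.estimator.epochs = epochs

m = _Module(ovr=True)
m.set_epochs(5)
assert all(est.epochs == 5 for est in m.estimator.values())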
b/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py index a4ae6e40d3..6d0bf8f178 100644 --- a/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py @@ -45,6 +45,14 @@ def __init__( self.estimator = None + def set_batch_size(self, batch_size): + self.batch_size = batch_size + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + self.estimator.epochs = epochs + def fit(self, ctx: Context) -> None: encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048)) ctx.hosts("encryptor").put(encryptor) @@ -126,7 +134,7 @@ def fit_model(self, ctx, decryptor): optimizer_ready = True # self.start_epoch = self.end_epoch + 1 - for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): iter_loss = None iter_g = None self.optimizer.set_iters(i) @@ -197,7 +205,9 @@ def get_model(self): "optimizer": self.optimizer.state_dict(), "lr_scheduler": self.lr_scheduler.state_dict(), "end_epoch": self.end_epoch, - "converged": self.is_converged + "is_converged": self.is_converged, + "tol": self.tol, + "early_stop": self.early_stop } def restore(self, model): @@ -207,5 +217,7 @@ def restore(self, model): self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer) self.end_epoch = model["end_epoch"] self.is_converged = model["is_converged"] + self.tol = model["tol"] + self.early_stop = model["early_stop"] self.converge_func = converge_func_factory(self.early_stop, self.tol) # self.start_epoch = model["end_epoch"] + 1 diff --git a/python/fate/ml/glm/hetero/coordinated_linr/guest.py b/python/fate/ml/glm/hetero/coordinated_linr/guest.py index de46955f3f..f1c8f1f3cb 100644 --- a/python/fate/ml/glm/hetero/coordinated_linr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/guest.py @@ -42,6 +42,14 @@ def __init__( self.estimator = None + def set_batch_size(self, batch_size): + self.batch_size = batch_size + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + self.estimator.epochs = epochs + def fit(self, ctx: Context, train_data, validate_data=None) -> None: if self.estimator is None: optimizer = Optimizer( @@ -79,7 +87,9 @@ def from_model(cls, model) -> "CoordinatedLinRModuleGuest": learning_rate_param=model["meta"]["learning_rate_param"], batch_size=model["meta"]["batch_size"], init_param=model["meta"]["init_param"]) - estimator = CoordinatedLinREstimatorGuest() + estimator = CoordinatedLinREstimatorGuest(epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"]) estimator.restore(model["data"]["estimator"]) linr.estimator = estimator @@ -123,7 +133,7 @@ def fit_model(self, ctx, train_data, validate_data=None): # if self.end_epoch >= 0: # self.start_epoch = self.end_epoch + 1 - for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch {i}") for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): diff --git a/python/fate/ml/glm/hetero/coordinated_linr/host.py b/python/fate/ml/glm/hetero/coordinated_linr/host.py index 30c4f69480..59b69401fa 100644 --- a/python/fate/ml/glm/hetero/coordinated_linr/host.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/host.py @@ -43,6 +43,14 @@ def __init__( self.estimator = None + def 
set_batch_size(self, batch_size): + self.batch_size = batch_size + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + self.estimator.epochs = epochs + def fit(self, ctx: Context, train_data, validate_data=None) -> None: encryptor = ctx.arbiter("encryptor").get() if self.estimator is None: @@ -82,7 +90,9 @@ def from_model(cls, model) -> "CoordinatedLinRModuleHost": epochs=model["meta"]["epochs"], batch_size=model["meta"]["batch_size"], init_param=model["meta"]["init_param"]) - estimator = CoordiantedLinREstimatorHost() + estimator = CoordiantedLinREstimatorHost(epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"]) estimator.restore(model["data"]["estimator"]) linr.estimator = estimator @@ -120,7 +130,7 @@ def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) -> self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer) # if self.end_epoch >= 0: # self.start_epoch = self.end_epoch + 1 - for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch {i}") for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): diff --git a/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py index 1ad6aa5bce..ffc3694bf2 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py @@ -38,6 +38,22 @@ def __init__(self, epochs, early_stop, tol, batch_size, optimizer_param, learnin self.estimator = None self.ovr = False + def set_batch_size(self, batch_size): + self.batch_size = batch_size + if self.ovr: + for estimator in self.estimator.values(): + estimator.batch_size = batch_size + else: + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + if self.ovr: + for estimator in self.estimator.values(): + estimator.epochs = epochs + else: + self.estimator.epochs = epochs + def fit(self, ctx: Context) -> None: encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048)) ctx.hosts("encryptor").put(encryptor) @@ -169,7 +185,7 @@ def fit_single_model(self, ctx: Context, decryptor): else: optimizer_ready = True # self.start_epoch = self.end_epoch + 1 - for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): iter_loss = None iter_g = None self.optimizer.set_iters(i) diff --git a/python/fate/ml/glm/hetero/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py index ddeda64594..f410992a71 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py @@ -46,6 +46,22 @@ def __init__( self.ovr = False self.labels = None + def set_batch_size(self, batch_size): + self.batch_size = batch_size + if self.ovr: + for estimator in self.estimator.values(): + estimator.batch_size = batch_size + else: + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + if self.ovr: + for estimator in self.estimator.values(): + estimator.epochs = epochs + else: + self.estimator.epochs = epochs + def fit(self, ctx: Context, train_data, validate_data=None) -> None: original_label = train_data.label train_data_binarized_label = train_data.label.get_dummies() @@ -214,7 +230,7 
@@ def fit_single_model(self, ctx: Context, train_data, validate_data=None): # if self.end_epoch >= 0: # self.start_epoch = self.end_epoch + 1 - for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch {i}") for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): diff --git a/python/fate/ml/glm/hetero/coordinated_lr/host.py b/python/fate/ml/glm/hetero/coordinated_lr/host.py index cad6e7cec5..5b8642026f 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/host.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/host.py @@ -42,6 +42,22 @@ def __init__( self.ovr = False self.label_count = False + def set_batch_size(self, batch_size): + self.batch_size = batch_size + if self.ovr: + for estimator in self.estimator.values(): + estimator.batch_size = batch_size + else: + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + if self.ovr: + for estimator in self.estimator.values(): + estimator.epochs = epochs + else: + self.estimator.epochs = epochs + def fit(self, ctx: Context, train_data, validate_data=None) -> None: encryptor = ctx.arbiter("encryptor").get() label_count = ctx.guest("label_count").get() @@ -173,7 +189,7 @@ def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=No batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host") # if self.end_epoch >= 0: # self.start_epoch = self.end_epoch + 1 - for i, iter_ctx in ctx.on_iterations.ctxs_range(self.start_epoch, self.epochs): + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): self.optimizer.set_iters(i) logger.info(f"self.optimizer set epoch{i}") for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): From 7acbfd981fa56b9e9de211bde2d74ee819d5a701 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Mon, 17 Jul 2023 16:13:34 +0800 Subject: [PATCH 48/61] dataframe: add copy interface Signed-off-by: mgqa34 --- python/fate/arch/dataframe/_dataframe.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/fate/arch/dataframe/_dataframe.py b/python/fate/arch/dataframe/_dataframe.py index e41927ef98..47427a5f87 100644 --- a/python/fate/arch/dataframe/_dataframe.py +++ b/python/fate/arch/dataframe/_dataframe.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import copy import operator from typing import List, Union @@ -537,6 +538,14 @@ def _merge_list(lhs, rhs): def iloc(self, indexes): ... 
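
The copy interface added just below rebuilds a frame from three pieces: a re-mapped block table (mapValues(lambda v: v) yields a fresh table handle without materializing data), a deep copy of the partition-order mappings, and a duplicated data manager. Patch 49 later relies on this so that each one-vs-rest estimator can rebind the label column without mutating the shared training frame. A minimal standalone sketch of the same construction pattern, with a plain dict standing in for the distributed block table (MiniFrame and its fields are illustrative assumptions, not the real DataFrame internals):

    import copy

    class MiniFrame:
        def __init__(self, blocks, mappings):
            self.blocks = blocks        # stands in for the distributed block table
            self.mappings = mappings    # stands in for partition_order_mappings

        def copy(self) -> "MiniFrame":
            # the real code calls block_table.mapValues(lambda v: v) to obtain a
            # new table handle; a shallow dict rebuild plays that role here
            return MiniFrame(dict(self.blocks), copy.deepcopy(self.mappings))

    f1 = MiniFrame({0: [1, 2]}, {"order": [0]})
    f2 = f1.copy()
    f2.mappings["order"].append(1)      # leaves f1.mappings untouched

The addition itself follows:
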
+ def copy(self) -> "DataFrame": + return DataFrame( + self._ctx, + self._block_table.mapValues(lambda v: v), + copy.deepcopy(self.partition_order_mappings), + self._data_manager.duplicate() + ) + @classmethod def hstack(cls, stacks: List["DataFrame"]) -> "DataFrame": from .ops._dimension_scaling import hstack From 9e65e162a94adcce488ff4367f171e9331ec1fbf Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 17 Jul 2023 16:46:27 +0800 Subject: [PATCH 49/61] fix lr & linr export model param & ovr training(#4659) lr & linr use predict utils(#4659) Signed-off-by: Yu Wu --- .../components/components/coordinated_linr.py | 38 ++++++++++------- .../components/components/coordinated_lr.py | 41 +++++++++++-------- .../ml/glm/hetero/coordinated_linr/guest.py | 6 ++- .../ml/glm/hetero/coordinated_lr/guest.py | 29 ++++++++++--- .../fate/ml/glm/hetero/coordinated_lr/host.py | 2 +- python/fate/ml/utils/_model_param.py | 2 +- 6 files changed, 76 insertions(+), 42 deletions(-) diff --git a/python/fate/components/components/coordinated_linr.py b/python/fate/components/components/coordinated_linr.py index 118520a5c0..7d46b6f4c5 100644 --- a/python/fate/components/components/coordinated_linr.py +++ b/python/fate/components/components/coordinated_linr.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from fate.arch import Context from fate.arch.dataframe import DataFrame +from fate.components.components.utils import consts, tools from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params from fate.ml.glm import CoordinatedLinRModuleArbiter, CoordinatedLinRModuleGuest, CoordinatedLinRModuleHost @@ -176,15 +176,18 @@ def cross_validation( module.fit(fold_ctx, train_data, validate_data) if output_cv_data: sub_ctx = fold_ctx.sub_ctx("predict_train") - predict_score = module.predict(sub_ctx, train_data) - train_predict_result = transform_to_predict_result( + train_predict_df = module.predict(sub_ctx, train_data) + """train_predict_result = transform_to_predict_result( train_data, predict_score, data_type="train" - ) + )""" + train_predict_result = tools.add_dataset_type(train_predict_df, consts.TRAIN_SET) sub_ctx = fold_ctx.sub_ctx("predict_validate") - predict_score = module.predict(sub_ctx, validate_data) - validate_predict_result = transform_to_predict_result( + validate_predict_df = module.predict(sub_ctx, validate_data) + validate_predict_result = tools.add_dataset_type(validate_predict_df, consts.VALIDATE_SET) + """validate_predict_result = transform_to_predict_result( validate_data, predict_score, data_type="predict" ) + """ predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) next(cv_output_datas).write(df=predict_result) @@ -230,14 +233,18 @@ def train_guest(ctx, train_data, validate_data, train_output_data, output_model, sub_ctx = ctx.sub_ctx("predict") - predict_score = module.predict(sub_ctx, train_data) - predict_result = transform_to_predict_result(train_data, predict_score, - data_type="train") + predict_df = module.predict(sub_ctx, train_data) + """predict_result = transform_to_predict_result(train_data, predict_score, + data_type="train")""" + predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) if validate_data is not None: sub_ctx = ctx.sub_ctx("validate_predict") - predict_score = module.predict(sub_ctx, validate_data) - validate_predict_result = transform_to_predict_result(validate_data, predict_score, + predict_df = 
module.predict(sub_ctx, validate_data) + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) + + """validate_predict_result = transform_to_predict_result(validate_data, predict_score, data_type="validate") + """ predict_result = DataFrame.vstack([predict_result, validate_predict_result]) train_output_data.write(predict_result) @@ -300,8 +307,9 @@ def predict_guest(ctx, input_model, test_data, test_output_data): module = CoordinatedLinRModuleGuest.from_model(model) test_data = test_data.read() - predict_score = module.predict(sub_ctx, test_data) - predict_result = transform_to_predict_result(test_data, predict_score, data_type="predict") + predict_result = module.predict(sub_ctx, test_data) + predict_result = tools.add_dataset_type(predict_result, consts.TEST_SET) + # predict_result = transform_to_predict_result(test_data, predict_score, data_type="predict") test_output_data.write(predict_result) @@ -313,7 +321,7 @@ def predict_host(ctx, input_model, test_data, test_output_data): module.predict(sub_ctx, test_data) -def transform_to_predict_result(test_data, predict_score, data_type="test"): +"""def transform_to_predict_result(test_data, predict_score, data_type="test"): df = test_data.create_frame(with_label=True, with_weight=False) pred_res = test_data.create_frame(with_label=False, with_weight=False) pred_res["predict_result"] = predict_score @@ -322,4 +330,4 @@ def transform_to_predict_result(test_data, predict_score, data_type="test"): v[0], json.dumps({"label": v[0]}), data_type], enable_type_align_checking=False) - return df + return df""" diff --git a/python/fate/components/components/coordinated_lr.py b/python/fate/components/components/coordinated_lr.py index 03de669234..46a33e00fe 100644 --- a/python/fate/components/components/coordinated_lr.py +++ b/python/fate/components/components/coordinated_lr.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
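
coordinated_lr.py receives the same treatment as coordinated_linr.py above: the per-component transform_to_predict_result helper is retired (kept only as a commented-out reference), module.predict() now returns a finished prediction frame, and the component merely stamps which dataset the rows came from before writing or vstack-ing them. A hedged sketch of what that tagging step amounts to (the "type" column follows the old helper's output schema, but the constants and function body are assumptions for illustration, not the verified fate.components.components.utils API):

    TRAIN_SET, VALIDATE_SET, TEST_SET = "train_set", "validate_set", "test_set"

    def add_dataset_type(predict_df, dataset_type):
        # assumed behaviour: mark every prediction row with its origin so that
        # DataFrame.vstack([train_result, validate_result]) stays
        # distinguishable downstream
        if dataset_type not in (TRAIN_SET, VALIDATE_SET, TEST_SET):
            raise ValueError(f"unknown dataset type: {dataset_type}")
        predict_df["type"] = dataset_type
        return predict_df
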
-import json import logging from fate.arch import Context from fate.arch.dataframe import DataFrame +from fate.components.components.utils import consts, tools from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params from fate.ml.glm import CoordinatedLRModuleGuest, CoordinatedLRModuleHost, CoordinatedLRModuleArbiter @@ -217,17 +217,19 @@ def cross_validation( module.fit(fold_ctx, train_data, validate_data) if output_cv_data: sub_ctx = fold_ctx.sub_ctx("predict_train") - predict_score = module.predict(sub_ctx, train_data) - train_predict_result = transform_to_predict_result( + predict_df = module.predict(sub_ctx, train_data) + """train_predict_result = transform_to_predict_result( train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="train" - ) + )""" + train_predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) sub_ctx = fold_ctx.sub_ctx("predict_validate") - predict_score = module.predict(sub_ctx, validate_data) - validate_predict_result = transform_to_predict_result( + predict_df = module.predict(sub_ctx, validate_data) + """validate_predict_result = transform_to_predict_result( validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="predict" - ) + )""" + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) next(cv_output_datas).write(df=predict_result) @@ -294,21 +296,23 @@ def train_guest( sub_ctx = ctx.sub_ctx("predict") - predict_score = module.predict(sub_ctx, train_data) - predict_result = transform_to_predict_result( + predict_df = module.predict(sub_ctx, train_data) + """predict_result = transform_to_predict_result( train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="train" - ) + )""" + predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) if validate_data is not None: sub_ctx = ctx.sub_ctx("validate_predict") - predict_score = module.predict(sub_ctx, validate_data) - validate_predict_result = transform_to_predict_result( + predict_df = module.predict(sub_ctx, validate_data) + """validate_predict_result = transform_to_predict_result( validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="validate", - ) + )""" + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) predict_result = DataFrame.vstack([predict_result, validate_predict_result]) train_output_data.write(predict_result) @@ -390,10 +394,11 @@ def predict_guest(ctx, input_model, test_data, test_output_data): # if module.threshold != 0.5: # module.threshold = threshold test_data = test_data.read() - predict_score = module.predict(sub_ctx, test_data) - predict_result = transform_to_predict_result( + predict_df = module.predict(sub_ctx, test_data) + """predict_result = transform_to_predict_result( test_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="test" - ) + )""" + predict_result = tools.add_dataset_type(predict_df, consts.TEST_SET) test_output_data.write(predict_result) @@ -406,7 +411,7 @@ def predict_host(ctx, input_model, test_data, test_output_data): module.predict(sub_ctx, test_data) -def transform_to_predict_result(test_data, predict_score, labels, threshold=0.5, is_ovr=False, data_type="test"): +"""def transform_to_predict_result(test_data, predict_score, labels, threshold=0.5, is_ovr=False, 
data_type="test"): if is_ovr: df = test_data.create_frame(with_label=True, with_weight=False) df[["predict_result", "predict_score", "predict_detail", "type"]] = predict_score.apply_row( @@ -424,4 +429,4 @@ def transform_to_predict_result(test_data, predict_score, labels, threshold=0.5, lambda v: [int(v[0] > threshold), v[0], json.dumps({1: v[0], 0: 1 - v[0]}), data_type], enable_type_align_checking=False, ) - return df + return df""" diff --git a/python/fate/ml/glm/hetero/coordinated_linr/guest.py b/python/fate/ml/glm/hetero/coordinated_linr/guest.py index f1c8f1f3cb..ba490b39cd 100644 --- a/python/fate/ml/glm/hetero/coordinated_linr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/guest.py @@ -19,6 +19,7 @@ from fate.arch import dataframe, Context from fate.ml.abc.module import HeteroModule +from fate.ml.utils import predict_tools from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param from fate.ml.utils._optimizer import Optimizer, LRScheduler @@ -192,7 +193,10 @@ def predict(self, ctx, test_data): pred = torch.matmul(X, self.w) for h_pred in ctx.hosts.get("h_pred"): pred += h_pred - return pred + pred_df = test_data.create_frame(with_label=True, with_weight=False) + pred_df[predict_tools.PREDICT_SCORE] = pred + predict_result = predict_tools.compute_predict_details(pred_df, task_type=predict_tools.REGRESSION) + return predict_result def get_model(self): """w = self.w.tolist() diff --git a/python/fate/ml/glm/hetero/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py index f410992a71..99e66a2ea9 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py @@ -19,6 +19,7 @@ from fate.arch import Context, dataframe from fate.ml.abc.module import HeteroModule +from fate.ml.utils import predict_tools from fate.ml.utils._model_param import initialize_param, serialize_param, deserialize_param from fate.ml.utils._optimizer import LRScheduler, Optimizer @@ -104,8 +105,12 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: single_estimator = self.estimator[i] single_estimator.epochs = self.epochs single_estimator.batch_size = self.batch_size - train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]] - single_estimator.fit_single_model(class_ctx, train_data, validate_data) + class_train_data = train_data.copy() + class_validate_data = validate_data + if validate_data: + class_validate_data = validate_data.copy() + class_train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]] + single_estimator.fit_single_model(class_ctx, class_train_data, class_validate_data) self.estimator[i] = single_estimator else: @@ -136,14 +141,26 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None: def predict(self, ctx, test_data): if self.ovr: - predict_score = test_data.create_frame(with_label=False, with_weight=False) + pred_score = test_data.create_frame(with_label=False, with_weight=False) for i, class_ctx in ctx.sub_ctx("class").ctxs_range(len(self.labels)): estimator = self.estimator[i] pred = estimator.predict(class_ctx, test_data) - predict_score[self.labels[i]] = pred + pred_score[self.labels[i]] = pred + pred_df = test_data.create_frame(with_label=True, with_weight=False) + pred_df[predict_tools.PREDICT_SCORE] = pred_score.apply_row(lambda v: [list(v)]) + predict_result = predict_tools.compute_predict_details(pred_df, + task_type=predict_tools.MULTI, + classes=self.labels) else: predict_score 
= self.estimator.predict(ctx, test_data) - return predict_score + pred_df = test_data.create_frame(with_label=True, with_weight=False) + pred_df[predict_tools.PREDICT_SCORE] = predict_score + predict_result = predict_tools.compute_predict_details(pred_df, + task_type=predict_tools.BINARY, + classes=self.labels, + threshold=self.threshold) + + return predict_result def get_model(self): all_estimator = {} @@ -292,7 +309,7 @@ def predict(self, ctx, test_data): if self.init_param.get("fit_intercept"): test_data["intercept"] = 1.0 X = test_data.values.as_tensor() - logger.info(f"in predict, w: {self.w}") + # logger.info(f"in predict, w: {self.w}") pred = torch.matmul(X, self.w) for h_pred in ctx.hosts.get("h_pred"): pred += h_pred diff --git a/python/fate/ml/glm/hetero/coordinated_lr/host.py b/python/fate/ml/glm/hetero/coordinated_lr/host.py index 5b8642026f..c315b6b8c7 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/host.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/host.py @@ -61,7 +61,7 @@ def set_epochs(self, epochs): def fit(self, ctx: Context, train_data, validate_data=None) -> None: encryptor = ctx.arbiter("encryptor").get() label_count = ctx.guest("label_count").get() - if self.label_count > 2 or self.ovr: + if label_count > 2 or self.ovr: self.ovr = True self.label_count = label_count warm_start = True diff --git a/python/fate/ml/utils/_model_param.py b/python/fate/ml/utils/_model_param.py index 175ec3d603..c79f90d287 100644 --- a/python/fate/ml/utils/_model_param.py +++ b/python/fate/ml/utils/_model_param.py @@ -41,8 +41,8 @@ def serialize_param(param, fit_intercept=False): w = param.tolist() intercept = None if fit_intercept: - w = w[:-1] intercept = w[-1] + w = w[:-1] return {"coef_": w, "intercept_": intercept, "dtype": dtype} From cbfff35659941b047a57fe89e54ca647e7c1e2d8 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 17 Jul 2023 17:12:06 +0800 Subject: [PATCH 50/61] fix lr multi load model(#4659) Signed-off-by: Yu Wu --- .../fate/ml/glm/hetero/coordinated_lr/arbiter.py | 8 ++++++-- python/fate/ml/glm/hetero/coordinated_lr/guest.py | 11 +++++++---- python/fate/ml/glm/hetero/coordinated_lr/host.py | 14 ++++++++++++-- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py index ffc3694bf2..538563536b 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py @@ -148,9 +148,13 @@ def from_model(cls, model) -> "CoordinatedLRModuleArbiter": learning_rate_param=model["meta"]["learning_rate_param"], ) all_estimator = model["data"]["estimator"] - + lr.estimator = {} if lr.ovr: - lr.estimator = {label: CoordinatedLREstimatorArbiter().restore(d) for label, d in all_estimator.items()} + for label, d in all_estimator.items(): + estimator = CoordinatedLREstimatorArbiter(epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"]) + estimator.restore(d) + lr.estimator[int(label)] = estimator else: estimator = CoordinatedLREstimatorArbiter() estimator.restore(all_estimator) diff --git a/python/fate/ml/glm/hetero/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py index 99e66a2ea9..9a95aa8730 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/guest.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py @@ -192,11 +192,14 @@ def from_model(cls, model) -> "CoordinatedLRModuleGuest": lr.labels = model["meta"]["labels"] all_estimator = model["data"]["estimator"] + 
lr.estimator = {} if lr.ovr: - lr.estimator = {label: CoordinatedLREstimatorGuest(epochs=model["meta"]["epochs"], - batch_size=model["meta"]["batch_size"], - init_param=model["meta"]["init_param"]). \ - restore(d) for label, d in all_estimator.items()} + for label, d in all_estimator.items(): + estimator = CoordinatedLREstimatorGuest(epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"]) + estimator.restore(d) + lr.estimator[int(label)] = estimator else: estimator = CoordinatedLREstimatorGuest(epochs=model["meta"]["epochs"], batch_size=model["meta"]["batch_size"], diff --git a/python/fate/ml/glm/hetero/coordinated_lr/host.py b/python/fate/ml/glm/hetero/coordinated_lr/host.py index c315b6b8c7..933d816675 100644 --- a/python/fate/ml/glm/hetero/coordinated_lr/host.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/host.py @@ -155,10 +155,20 @@ def from_model(cls, model) -> "CoordinatedLRModuleHost": lr.ovr = model["meta"]["ovr"] all_estimator = model["data"]["estimator"] + lr.estimator = {} + if lr.ovr: - lr.estimator = {label: CoordinatedLREstimatorHost().restore(d) for label, d in all_estimator.items()} + lr.estimator = {} + for label, d in all_estimator.items(): + estimator = CoordinatedLREstimatorHost(epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"]) + estimator.restore(d) + lr.estimator[int(label)] = estimator else: - estimator = CoordinatedLREstimatorHost() + estimator = CoordinatedLREstimatorHost(epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"]) estimator.restore(all_estimator) lr.estimator = estimator logger.info(f"finish from model") From 2f09d74d00b884325d35a6b755002a82e48577f0 Mon Sep 17 00:00:00 2001 From: weiwee Date: Tue, 18 Jul 2023 15:20:15 +0800 Subject: [PATCH 51/61] add sub to hist Signed-off-by: weiwee --- python/fate/arch/tensor/inside/_op_hist.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/python/fate/arch/tensor/inside/_op_hist.py b/python/fate/arch/tensor/inside/_op_hist.py index a2ddb7823f..7bd4a03c7f 100644 --- a/python/fate/arch/tensor/inside/_op_hist.py +++ b/python/fate/arch/tensor/inside/_op_hist.py @@ -1,6 +1,9 @@ +import typing + + class Hist: def __init__(self): - self.data = {} + self.data: typing.Dict[int, typing.Dict[int, typing.Any]] = {} def update(self, features, labels): shape_x, shape_y = features.shape @@ -34,6 +37,17 @@ def cumsum(self): self.data[k][kk] = s return self + def __sub__(self, other: "Hist"): + out = Hist() + for j in self.data: + out.data[j] = {} + for v in self.data[j]: + if v not in other.data[j]: + out.data[j][v] = self.data[j][v] + else: + out.data[j][v] = self.data[j][v] - other.data[j][v] + return out + if __name__ == "__main__": import numpy as np @@ -42,4 +56,4 @@ def cumsum(self): features = np.array([[1, 0], [0, 1], [2, 1], [2, 0]]) labels = np.array([0, 1, 0, 0]) hist.update(features, labels) - print(hist.data) + print((hist - hist).data) From 6574e65c9f936528286e990bd3d6e4cc14fbfa46 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 18 Jul 2023 16:49:27 +0800 Subject: [PATCH 52/61] fix optimizer (#4659) Signed-off-by: Yu Wu --- python/fate/ml/utils/_optimizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/fate/ml/utils/_optimizer.py b/python/fate/ml/utils/_optimizer.py index 0ac158d64d..5e584815b1 100644 --- a/python/fate/ml/utils/_optimizer.py +++ 
b/python/fate/ml/utils/_optimizer.py @@ -305,6 +305,8 @@ def optimizer_factory(model_parameter, optimizer_type, optim_params): return torch.optim.RAdam(model_parameter, **optimizer_params) elif optimizer_type == 'rmsprop': return torch.optim.RMSprop(model_parameter, **optimizer_params) + elif optimizer_type == "rprop": + return torch.optim.Rprop(model_parameter, **optimizer_params) elif optimizer_type == 'sgd': return torch.optim.SGD(model_parameter, **optimizer_params) else: From 4fad5b832798a903874f8ff2fd582119e707388a Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 18 Jul 2023 16:50:42 +0800 Subject: [PATCH 53/61] edit examples (#4659) Signed-off-by: Yu Wu --- examples/pipeline/test_lr_sid.py | 25 +++++++++++++---------- examples/pipeline/test_single_linr.py | 2 +- examples/pipeline/test_single_lr_multi.py | 2 +- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/examples/pipeline/test_lr_sid.py b/examples/pipeline/test_lr_sid.py index 3a61bd48ca..b5fba7ae8f 100644 --- a/examples/pipeline/test_lr_sid.py +++ b/examples/pipeline/test_lr_sid.py @@ -20,15 +20,15 @@ pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") intersect_0 = Intersection("intersect_0", method="raw") -intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", +intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_multi", namespace="experiment_sid")) intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", namespace="experiment_sid")) lr_0 = CoordinatedLR("lr_0", - epochs=2, - batch_size=100, - optimizer={"method": "sgd", "optimizer_params": {"lr": 0.01}}, - init_param={"fit_intercept": True}, + epochs=4, + batch_size=None, + optimizer={"method": "rprop", "optimizer_params": {"lr": 0.01}}, + init_param={"fit_intercept": True, "method": "zeros"}, train_data=intersect_0.outputs["output_data"]) lr_1 = CoordinatedLR("lr_1", test_data=intersect_0.outputs["output_data"], input_model=lr_0.outputs["output_model"]) @@ -39,31 +39,33 @@ namespace="experiment"))""" evaluation_0 = Evaluation("evaluation_0", + label_column_name="y", runtime_roles=["guest"], + default_eval_setting="multi", input_data=lr_0.outputs["train_output_data"]) # pipeline.add_task(feature_scale_0) # pipeline.add_task(feature_scale_1) pipeline.add_task(intersect_0) pipeline.add_task(lr_0) -pipeline.add_task(evaluation_0) +# pipeline.add_task(evaluation_0) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() print(pipeline.get_dag()) pipeline.fit() +print(f"lr_0 model: {pipeline.get_task_info('lr_0').get_output_model()}") +print(f"train lr_0 data: {pipeline.get_task_info('lr_0').get_output_data()}") # print(pipeline.get_task_info("statistics_0").get_output_model()) -print(pipeline.get_task_info("lr_0").get_output_model()) -print(pipeline.get_task_info("lr_0").get_output_metrics()) -print(f"evaluation metrics: ") -print(pipeline.get_task_info("evaluation_0").get_output_metrics()) +# print(f"evaluation metrics: ") +# print(pipeline.get_task_info("evaluation_0").get_output_metric()) pipeline.deploy([intersect_0, lr_0]) predict_pipeline = FateFlowPipeline() deployed_pipeline = pipeline.get_deployed_pipeline() -deployed_pipeline.intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", +deployed_pipeline.intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_multi", namespace="experiment_sid")) 
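
The lr_0 task in this example now drives the rprop branch that patch 52 adds to optimizer_factory. A minimal sketch of that dispatch, assuming only that torch is installed (the surrounding FATE plumbing and the if/elif chain of the real factory are condensed into a table):

    import torch

    def optimizer_factory(model_parameter, optimizer_type, optim_params):
        optimizers = {
            "adam": torch.optim.Adam,
            "rmsprop": torch.optim.RMSprop,
            "rprop": torch.optim.Rprop,  # the branch added in this series
            "sgd": torch.optim.SGD,
        }
        if optimizer_type not in optimizers:
            raise NotImplementedError(f"optimizer {optimizer_type} not supported")
        return optimizers[optimizer_type](model_parameter, **(optim_params or {}))

    # optimizer={"method": "rprop", "optimizer_params": {"lr": 0.01}} in the
    # pipeline ends up roughly as:
    model = torch.nn.Linear(4, 1)
    opt = optimizer_factory(model.parameters(), "rprop", {"lr": 0.01})
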
deployed_pipeline.intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", namespace="experiment_sid")) @@ -73,3 +75,4 @@ # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) predict_pipeline.predict() +print(f"predict lr_0 data: {pipeline.get_task_info('lr_0').get_output_data()}") diff --git a/examples/pipeline/test_single_linr.py b/examples/pipeline/test_single_linr.py index e42bedebb7..ec58f83a78 100644 --- a/examples/pipeline/test_single_linr.py +++ b/examples/pipeline/test_single_linr.py @@ -29,7 +29,7 @@ linr_0 = CoordinatedLinR("linr_0", epochs=10, - batch_size=-1, + batch_size=None, init_param={"fit_intercept": False}) linr_0.guest.component_setting(train_data=DataWarehouseChannel(name="motor_hetero_guest", diff --git a/examples/pipeline/test_single_lr_multi.py b/examples/pipeline/test_single_lr_multi.py index ffee88351e..3dc1a6e41c 100644 --- a/examples/pipeline/test_single_lr_multi.py +++ b/examples/pipeline/test_single_lr_multi.py @@ -29,7 +29,7 @@ lr_0 = CoordinatedLR("lr_0", epochs=10, - batch_size=-1, + batch_size=None, init_param={"fit_intercept": False}) lr_0.guest.component_setting(train_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", From 034568e08f7c07b9181f9b427434858c0d749a32 Mon Sep 17 00:00:00 2001 From: weiwee Date: Tue, 18 Jul 2023 17:06:40 +0800 Subject: [PATCH 54/61] add index ctx Signed-off-by: weiwee --- python/fate/arch/context/_context.py | 7 ++++-- python/fate/arch/context/_namespace.py | 2 +- .../src/secure_aggregation_helper/mod.rs | 22 +++++++------------ 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index 476fa47126..1c778c0b68 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -101,8 +101,11 @@ def computing(self): def federation(self) -> "FederationEngine": return self._get_federation() - def sub_ctx(self, name: str, is_special=False) -> "Context": - return self.with_namespace(self._namespace.sub_ns(name=name, is_special=is_special)) + def sub_ctx(self, name: str) -> "Context": + return self.with_namespace(self._namespace.sub_ns(name=name)) + + def indexed_ctx(self, index: int) -> "Context": + return self.with_namespace(self._namespace.indexed_ns(index)) @property def on_iterations(self) -> "Context": diff --git a/python/fate/arch/context/_namespace.py b/python/fate/arch/context/_namespace.py index 5c77e76a0f..ab3b5f54a2 100644 --- a/python/fate/arch/context/_namespace.py +++ b/python/fate/arch/context/_namespace.py @@ -53,7 +53,7 @@ def __str__(self) -> str: def indexed_ns(self, index: int): return IndexedNS(index=index, name=self.name, deep=self.deep, parent=self.parent) - def sub_ns(self, name: str, is_special=False): + def sub_ns(self, name: str): return NS(name=name, deep=self.deep + 1, parent=self) diff --git a/rust/fate_utils/crates/fate_utils/src/secure_aggregation_helper/mod.rs b/rust/fate_utils/crates/fate_utils/src/secure_aggregation_helper/mod.rs index 0c10b9805a..afa5e487f7 100644 --- a/rust/fate_utils/crates/fate_utils/src/secure_aggregation_helper/mod.rs +++ b/rust/fate_utils/crates/fate_utils/src/secure_aggregation_helper/mod.rs @@ -1,20 +1,16 @@ -use core::f32; -use curve25519_dalek::edwards::EdwardsPoint; -use curve25519_dalek::montgomery::MontgomeryPoint; -use curve25519_dalek::scalar::Scalar; +use std::collections::HashMap; + use ndarray; use ndarray::prelude::*; -use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayDyn, 
PyReadonlyArray1, PyReadonlyArrayDyn}; +use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn}; use pyo3::exceptions::{PyIndexError, PyTypeError}; use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyBytes}; -use pyo3::wrap_pyfunction; +use pyo3::types::PyBytes; use rand::distributions::Uniform; -use rand::SeedableRng; -use rand_core::OsRng; use rand::Rng; +use rand::SeedableRng; use rand_chacha::ChaCha20Rng; -use std::collections::HashMap; +use rand_core::OsRng; use x25519_dalek::{EphemeralSecret, PublicKey}; #[pyclass] @@ -111,11 +107,9 @@ impl RandomMix { } }; let range = Uniform::new(-1e7f64, 1e7f64); - input.as_array() - .iter() - .zip(output_decimal_array.iter_mut()) + output_decimal_array.iter_mut() .zip(output_integer_array.iter_mut()) - .for_each(|((input, output_decimal), output_integer)| { + .for_each(|(output_decimal, output_integer)| { for state in self.states.iter_mut() { let rand = state.random_state.sample(range); state.index += 1; From 11e3d86d3d903b9af59a8d876cabf57a24e215d1 Mon Sep 17 00:00:00 2001 From: cwj Date: Tue, 18 Jul 2023 17:19:45 +0800 Subject: [PATCH 55/61] Adapt new aggregator Signed-off-by: cwj --- .../ml/aggregator/plaintext_aggregator.py | 132 +++++------------- .../ml/aggregator/test/test_aggregator.py | 72 ++++++++++ .../ml/aggregator/test/test_fate_utils.py | 6 +- 3 files changed, 113 insertions(+), 97 deletions(-) create mode 100644 python/fate/ml/aggregator/test/test_aggregator.py diff --git a/python/fate/ml/aggregator/plaintext_aggregator.py b/python/fate/ml/aggregator/plaintext_aggregator.py index 493abfb97f..b8d09cfc04 100644 --- a/python/fate/ml/aggregator/plaintext_aggregator.py +++ b/python/fate/ml/aggregator/plaintext_aggregator.py @@ -4,6 +4,8 @@ from typing import Union from .base import Aggregator import logging +from fate.arch.protocol._dh import SecureAggregatorClient as sa_client +from fate.arch.protocol._dh import SecureAggregatorServer as sa_server logger = logging.getLogger(__name__) @@ -18,12 +20,13 @@ class PlainTextAggregatorClient(Aggregator): PlainTextAggregatorClient is used to aggregate plain text data """ - def __init__(self, ctx: Context, aggregator_name=None, aggregate_type='mean', sample_num=1) -> None: + def __init__(self, ctx: Context, aggregator_name: str = None, aggregate_type='mean', sample_num=1) -> None: super().__init__(ctx, aggregator_name) self.ctx = ctx self._weight = 1.0 - + self.aggregator_name = 'default' if aggregator_name is None else aggregator_name + if sample_num <= 0 and not isinstance(sample_num, int): raise ValueError("sample_num should be int greater than 0") @@ -44,25 +47,29 @@ def __init__(self, ctx: Context, aggregator_name=None, aggregate_type='mean', sa logger.info("aggregate weight is {}".format(self._weight)) + self.model_aggregator = sa_client(prefix=self.aggregator_name+'_model', is_mock=True) + self.model_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + self.loss_aggregator = sa_client(prefix=self.aggregator_name+'_loss', is_mock=True) + self.loss_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + def _process_model(self, model): to_agg = None if isinstance(model, np.ndarray) or isinstance(model, t.Tensor): to_agg = model * self._weight - return to_agg + return [to_agg] if isinstance(model, t.nn.Module): parameters = list(model.parameters()) - tmp_list = [[p.cpu().detach().numpy() for p in parameters if p.requires_grad]] + agg_list = [p.cpu().detach().numpy() for p in parameters if p.requires_grad] + elif isinstance(model, list): for p in model: assert 
isinstance( p, np.ndarray), 'expecting List[np.ndarray], but got {}'.format(p) - tmp_list = [model] + agg_list = model - to_agg = [[arr * self._weight for arr in arr_list] - for arr_list in tmp_list] - return to_agg + return agg_list def _recover_model(self, model, agg_model): @@ -75,48 +82,25 @@ def _recover_model(self, model, agg_model): else: return agg_model - def _send_loss(self, loss): - assert isinstance(loss, float) or isinstance( - loss, np.ndarray), 'illegal loss type {}, loss should be a float or a np array'.format(type(loss)) - loss_suffix = self.suffix['local_loss']() - self.ctx.arbiter.put(loss_suffix, loss) - - def _send_model(self, model: Union[np.ndarray, t.Tensor, t.nn.Module]): - """Sending model to arbiter for aggregation - - Parameters - ---------- - model : model can be: - A numpy array - A Weight instance(or subclass of Weights), see federatedml.framework.weights - List of numpy array - A pytorch model, is the subclass of torch.nn.Module - A pytorch optimizer, will extract param group from this optimizer as weights to aggregate - """ - # judge model type - to_agg_model = self._process_model(model) - suffix = self.suffix['local_model']() - self.ctx.arbiter.put(suffix, to_agg_model) - - def _get_aggregated_model(self): - return self.ctx.arbiter.get(self.suffix['agg_model']())[0] - - def _get_aggregated_loss(self): - return self.ctx.arbiter.get(self.suffix['agg_loss']())[0] - """ User API """ def model_aggregation(self, model): - self._send_model(model) - agg_model = self._get_aggregated_model() + to_send = self._process_model(model) + print('model is ', to_send) + agg_model = self.model_aggregator.secure_aggregate(self.ctx, to_send, self._weight) return self._recover_model(model, agg_model) def loss_aggregation(self, loss): - self._send_loss(loss) - + if isinstance(loss, t.Tensor): + loss = loss.detach.cpu().numpy() + else: + loss = np.array(loss) + loss = [loss] + agg_loss = self.loss_aggregator.secure_aggregate(self.ctx, loss, self._weight) + return agg_loss class PlainTextAggregatorServer(Aggregator): @@ -125,9 +109,10 @@ class PlainTextAggregatorServer(Aggregator): PlainTextAggregatorServer is used to aggregate plain text data """ - def __init__(self, ctx: Context, aggregator_name=None) -> None: + def __init__(self, ctx: Context, aggregator_name: str = None) -> None: super().__init__(ctx, aggregator_name) + weight_list = self._collect(self.suffix["local_weight"]()) weight_sum = sum(weight_list) ret_weight = [] @@ -138,6 +123,10 @@ def __init__(self, ctx: Context, aggregator_name=None) -> None: for idx, w in enumerate(ret_weight): self._broadcast(w, ret_suffix, idx) + self.aggregator_name = 'default' if aggregator_name is None else aggregator_name + self.model_aggregator = sa_server(prefix=self.aggregator_name+'_model', is_mock=True, ranks=[ctx.guest.rank, *ctx.hosts.ranks]) + self.loss_aggregator = sa_server(prefix=self.aggregator_name+'_loss', is_mock=True, ranks=[ctx.guest.rank, *ctx.hosts.ranks]) + def _check_party_id(self, party_id): # party idx >= -1, int if not isinstance(party_id, int): @@ -145,15 +134,11 @@ def _check_party_id(self, party_id): if party_id < -1: raise ValueError("party_id should be greater than -1") - def _collect(self, suffix, party_idx=-1): - self._check_party_id(party_idx) + def _collect(self, suffix): guest_item = [self.ctx.guest.get(suffix)] host_item = self.ctx.hosts.get(suffix) combine_list = guest_item + host_item - if party_idx == -1: - return combine_list - else: - return combine_list[party_idx] + return combine_list def 
_broadcast(self, data, suffix, party_idx=-1): self._check_party_id(party_idx) @@ -165,54 +150,13 @@ def _broadcast(self, data, suffix, party_idx=-1): else: self.ctx.hosts[party_idx - 1].put(suffix, data) - def _aggregate_model(self, party_idx=-1): - - # get suffix - suffix = self.suffix['local_model']() - # recv params for aggregation - models = self._collect(suffix=suffix, party_idx=party_idx) - agg_result = None - # Aggregate numpy groups - if isinstance(models[0], list): - # aggregation - agg_result = models[0] - # aggregate numpy model weights from all clients - for params_group in models[1:]: - for agg_params, params in zip( - agg_result, params_group): - for agg_p, p in zip(agg_params, params): - # agg_p: NumpyWeights or numpy array - agg_p += p - else: - raise ValueError('invalid aggregation format: {}'.format(models)) - - if agg_result is None: - raise ValueError( - 'can not aggregate receive model, format is illegal: {}'.format(models)) - - return agg_result - - def _aggregate_loss(self, party_idx=-1): - - # get loss - loss_suffix = self.suffix['local_loss']() - losses = self._collect(suffix=loss_suffix, party_idx=-1) - total_loss = losses[0] - for loss in losses[1:]: - total_loss += loss - - return total_loss - """ User API """ - def model_aggregation(self, party_idx=-1): - agg_model = self._aggregate_model(party_idx=party_idx) - suffix = self.suffix['agg_model']() - self._broadcast(agg_model, suffix=suffix, party_idx=party_idx) - return agg_model + def model_aggregation(self, ranks=None): + self.model_aggregator.secure_aggregate(self.ctx, ranks=ranks) - def loss_aggregation(self, party_idx=-1): - agg_loss = self._aggregate_loss(party_idx=party_idx) - return agg_loss \ No newline at end of file + def loss_aggregation(self, ranks=None): + self.loss_aggregator.secure_aggregate(self.ctx, ranks=ranks) + \ No newline at end of file diff --git a/python/fate/ml/aggregator/test/test_aggregator.py b/python/fate/ml/aggregator/test/test_aggregator.py new file mode 100644 index 0000000000..177e8e210c --- /dev/null +++ b/python/fate/ml/aggregator/test/test_aggregator.py @@ -0,0 +1,72 @@ +import sys +import torch as t + + +arbiter = ("arbiter", 10000) +guest = ("guest", 10000) +host = ("host", 9999) +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context(computing=computing, + federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])) + + +if __name__ == "__main__": + + epoch = 10 + + if sys.argv[1] == "guest": + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient + + ctx = create_ctx(guest) + client = PlainTextAggregatorClient(ctx, sample_num=100, aggregate_type='weighted_mean') + model = t.nn.Sequential( + t.nn.Linear(10, 10), + t.nn.ReLU(), + t.nn.Linear(10, 1), + t.nn.Sigmoid() + ) + + for i in range(epoch): + client.model_aggregation(model) + elif sys.argv[1] == "host": + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient + + ctx = create_ctx(host) + client = PlainTextAggregatorClient(ctx, 
sample_num=100, aggregate_type='weighted_mean') + model = t.nn.Sequential( + t.nn.Linear(10, 10), + t.nn.ReLU(), + t.nn.Linear(10, 1), + t.nn.Sigmoid() + ) + + for i in range(epoch): + client.model_aggregation(model) + + else: + + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorServer + ctx = create_ctx(arbiter) + server = PlainTextAggregatorServer(ctx) + + for i in range(epoch): + server.model_aggregation() + diff --git a/python/fate/ml/aggregator/test/test_fate_utils.py b/python/fate/ml/aggregator/test/test_fate_utils.py index 4a3c2fab6b..51ddd34311 100644 --- a/python/fate/ml/aggregator/test/test_fate_utils.py +++ b/python/fate/ml/aggregator/test/test_fate_utils.py @@ -32,7 +32,7 @@ def create_ctx(local): import numpy as np ctx = create_ctx(guest) - client = SecureAggregatorClient() + client = SecureAggregatorClient(is_mock=True) client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) print('ranks are {}'.format([ctx.guest.rank, *ctx.hosts.ranks])) print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))])) @@ -41,12 +41,12 @@ def create_ctx(local): import numpy as np ctx = create_ctx(host) - client = SecureAggregatorClient() + client = SecureAggregatorClient(is_mock=True) client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))])) else: from fate.arch.protocol import SecureAggregatorServer ctx = create_ctx(arbiter) - server = SecureAggregatorServer([ctx.guest.rank, *ctx.hosts.ranks]) + server = SecureAggregatorServer([ctx.guest.rank, *ctx.hosts.ranks], is_mock=True) server.secure_aggregate(ctx) \ No newline at end of file From 7b34ea85ec7a0f63549489a228452a8d7d8ec0b5 Mon Sep 17 00:00:00 2001 From: weiwee Date: Tue, 18 Jul 2023 17:38:34 +0800 Subject: [PATCH 56/61] add feature names Signed-off-by: weiwee --- python/fate/arch/dataframe/ops/_histogram.py | 6 ++++-- python/fate/arch/tensor/inside/_op_hist.py | 8 ++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/fate/arch/dataframe/ops/_histogram.py b/python/fate/arch/dataframe/ops/_histogram.py index 1c26eafa5e..d8baa579e1 100644 --- a/python/fate/arch/dataframe/ops/_histogram.py +++ b/python/fate/arch/dataframe/ops/_histogram.py @@ -16,19 +16,21 @@ import functools from fate.arch.tensor.inside import Hist -from ._compress_block import compress_blocks + from .._dataframe import DataFrame from ..manager import BlockType, DataManager +from ._compress_block import compress_blocks def hist(df: DataFrame, targets): data_manager = df.data_manager + column_names = data_manager.infer_operable_field_names() block_table, data_manager = _try_to_compress_table(df.block_table, data_manager) block_id = data_manager.infer_operable_blocks()[0] def _mapper(blocks, target, bid: int = None): - histogram = Hist() + histogram = Hist(column_names) histogram.update(blocks[bid], target) return histogram diff --git a/python/fate/arch/tensor/inside/_op_hist.py b/python/fate/arch/tensor/inside/_op_hist.py index 7bd4a03c7f..748c524e4d 100644 --- a/python/fate/arch/tensor/inside/_op_hist.py +++ b/python/fate/arch/tensor/inside/_op_hist.py @@ -2,7 +2,8 @@ class Hist: - def __init__(self): + def __init__(self, feature_names): + self.feature_names = feature_names self.data: typing.Dict[int, typing.Dict[int, typing.Any]] = {} def update(self, features, labels): @@ -38,7 +39,7 @@ def cumsum(self): return self def __sub__(self, other: "Hist"): - out = Hist() + out = Hist(self.feature_names) for j in self.data: out.data[j] = {} for v 
in self.data[j]: @@ -48,6 +49,9 @@ def __sub__(self, other: "Hist"): out.data[j][v] = self.data[j][v] - other.data[j][v] return out + def to_dict(self): + return {name: self.data[i] for i, name in enumerate(self.feature_names)} + if __name__ == "__main__": import numpy as np From a40588df2260e73c2d7022ce99372e5c4b1ca6f7 Mon Sep 17 00:00:00 2001 From: weiwee Date: Tue, 18 Jul 2023 19:45:13 +0800 Subject: [PATCH 57/61] fix hist Signed-off-by: weiwee --- python/fate/arch/tensor/inside/_op_hist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/fate/arch/tensor/inside/_op_hist.py b/python/fate/arch/tensor/inside/_op_hist.py index 748c524e4d..623db546e9 100644 --- a/python/fate/arch/tensor/inside/_op_hist.py +++ b/python/fate/arch/tensor/inside/_op_hist.py @@ -10,7 +10,7 @@ def update(self, features, labels): shape_x, shape_y = features.shape for i in range(shape_x): for j in range(shape_y): - v = features[i, j] + v = features[i, j].item() if j not in self.data: self.data[j] = {} if v not in self.data[j]: From 743c7f0ff2c4cc8df05ac6b9adf5e3787d5df641 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Wed, 19 Jul 2023 11:53:45 +0800 Subject: [PATCH 58/61] fix model export when early stop=weight diff(#4659) Signed-off-by: Yu Wu --- python/fate/ml/utils/_convergence.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/fate/ml/utils/_convergence.py b/python/fate/ml/utils/_convergence.py index 153d529570..05498ecf52 100644 --- a/python/fate/ml/utils/_convergence.py +++ b/python/fate/ml/utils/_convergence.py @@ -78,7 +78,10 @@ def __init__(self, eps): def is_converge(self, delta_weight, weight=None): weight_diff = torch.linalg.norm(delta_weight, 2) if weight is None: - return weight_diff < self.eps + # avoid tensor[bool] + if weight_diff < self.eps: + return True + return False if self.pre_weight is None: self.pre_weight = weight return False From ce3fec22221cd29bbe14b9887165aea42e020ade Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Wed, 19 Jul 2023 11:58:55 +0800 Subject: [PATCH 59/61] edit example(#4659) Signed-off-by: Yu Wu --- examples/pipeline/test_lr_sid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/pipeline/test_lr_sid.py b/examples/pipeline/test_lr_sid.py index b5fba7ae8f..e8569d3b1a 100644 --- a/examples/pipeline/test_lr_sid.py +++ b/examples/pipeline/test_lr_sid.py @@ -20,7 +20,7 @@ pipeline = FateFlowPipeline().set_roles(guest="9999", host="9998", arbiter="9998") intersect_0 = Intersection("intersect_0", method="raw") -intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_multi", +intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", namespace="experiment_sid")) intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", namespace="experiment_sid")) @@ -41,7 +41,7 @@ evaluation_0 = Evaluation("evaluation_0", label_column_name="y", runtime_roles=["guest"], - default_eval_setting="multi", + default_eval_setting="binary", input_data=lr_0.outputs["train_output_data"]) # pipeline.add_task(feature_scale_0) @@ -65,7 +65,7 @@ predict_pipeline = FateFlowPipeline() deployed_pipeline = pipeline.get_deployed_pipeline() -deployed_pipeline.intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest_multi", +deployed_pipeline.intersect_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", namespace="experiment_sid")) 
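
The _convergence change a few hunks above (patch 58) exists because torch.linalg.norm returns a 0-dim tensor, so the comparison weight_diff < self.eps yields a tensor bool rather than a plain Python bool (hence the new "avoid tensor[bool]" comment), and per the commit message that is what broke model export when early_stop was set to weight diff. A quick reproduction of the type difference:

    import torch

    eps = 1e-4
    weight_diff = torch.linalg.norm(torch.tensor([1e-5, -2e-5]), 2)
    print(type(weight_diff < eps))  # <class 'torch.Tensor'>, not bool
    print(bool(weight_diff < eps))  # True: what the fixed branch returns
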
deployed_pipeline.intersect_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", namespace="experiment_sid")) From ed8407a7bb768a29d1659f39d077fe1fda65ed73 Mon Sep 17 00:00:00 2001 From: cwj Date: Wed, 19 Jul 2023 15:15:23 +0800 Subject: [PATCH 60/61] 1. Add New Secure Aggregate 2. Support Secure Agg Aggregator & Plaintext Aggregator 3. Float64 convert in Aggregator Signed-off-by: cwj --- python/fate/ml/aggregator/__init__.py | 17 ++ python/fate/ml/aggregator/base.py | 187 ++++++++++++++++- .../ml/aggregator/plaintext_aggregator.py | 158 +-------------- .../fate/ml/aggregator/secure_aggregator.py | 22 +- .../ml/aggregator/test/test_aggregator.py | 13 +- .../fate/ml/glm/homo/lr/test/test_fed_lr.py | 71 +++++++ .../lr/test/{local_test.py => test_local.py} | 3 +- python/fate/ml/nn/algo/homo/fedavg.py | 61 +++--- python/fate/ml/nn/dataset/table.py | 7 +- .../fate/ml/nn/trainer/test/test_trainer.py | 63 ++++++ python/fate/ml/nn/trainer/trainer_base.py | 190 +++++++++--------- 11 files changed, 492 insertions(+), 300 deletions(-) create mode 100644 python/fate/ml/aggregator/__init__.py create mode 100644 python/fate/ml/glm/homo/lr/test/test_fed_lr.py rename python/fate/ml/glm/homo/lr/test/{local_test.py => test_local.py} (95%) create mode 100644 python/fate/ml/nn/trainer/test/test_trainer.py diff --git a/python/fate/ml/aggregator/__init__.py b/python/fate/ml/aggregator/__init__.py new file mode 100644 index 0000000000..4909859034 --- /dev/null +++ b/python/fate/ml/aggregator/__init__.py @@ -0,0 +1,17 @@ +from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient, PlainTextAggregatorServer +from fate.ml.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer +import enum + + +class AggregatorType(enum.Enum): + PLAINTEXT = 'plaintext' + SECURE_AGGREGATE = 'secure_aggregate' + + +aggregator_map = { + AggregatorType.PLAINTEXT.value: (PlainTextAggregatorClient, PlainTextAggregatorServer), + AggregatorType.SECURE_AGGREGATE.value: (SecureAggregatorClient, SecureAggregatorServer) +} + + +__all__ = ['PlainTextAggregatorClient', 'PlainTextAggregatorServer', 'SecureAggregatorClient', 'SecureAggregatorServer'] diff --git a/python/fate/ml/aggregator/base.py b/python/fate/ml/aggregator/base.py index 535fd6ea96..cb9683269b 100644 --- a/python/fate/ml/aggregator/base.py +++ b/python/fate/ml/aggregator/base.py @@ -1,10 +1,18 @@ from fate.arch import Context from typing import Optional import logging +import numpy as np +from fate.arch.protocol._dh import SecureAggregatorClient as sa_client +from fate.arch.protocol._dh import SecureAggregatorServer as sa_server +import torch as t + logger = logging.getLogger(__name__) +AGGREGATE_TYPE = ['mean', 'sum', 'weighted_mean'] +TORCH_TENSOR_PRECISION = ['float32', 'float64'] + class AutoSuffix(object): @@ -25,7 +33,7 @@ def __call__(self): class Aggregator: def __init__(self, ctx: Context, aggregator_name: Optional[str] = None): - self.ctx = ctx + if aggregator_name is not None: agg_name = "_" + aggregator_name else: @@ -45,3 +53,180 @@ def model_aggregation(self, *args, **kwargs): def loss_aggregation(self, *args, **kwargs): raise NotImplementedError("loss_aggregation should be implemented in subclass") + + + +class BaseAggregatorClient(Aggregator): + + + def __init__(self, ctx: Context, aggregator_name: str = None, aggregate_type='mean', sample_num=1, is_mock=True, require_grad=True, float_p='float64') -> None: + + super().__init__(ctx, aggregator_name) + self._weight = 1.0 + 
self.aggregator_name = 'default' if aggregator_name is None else aggregator_name + self.require_grad = require_grad + + assert float_p in TORCH_TENSOR_PRECISION, "float_p should be one of {}".format(TORCH_TENSOR_PRECISION) + self.float_p = float_p + + if sample_num <= 0 and not isinstance(sample_num, int): + raise ValueError("sample_num should be int greater than 0") + + logger.info('computing weights') + if aggregate_type not in AGGREGATE_TYPE: + raise ValueError("aggregate_type should be one of {}".format(AGGREGATE_TYPE)) + elif aggregate_type == 'mean': + ctx.arbiter.put(self.suffix["local_weight"](), 1.0) + self._weight = ctx.arbiter.get(self.suffix["computed_weight"]()) + elif aggregate_type == 'sum': + ctx.arbiter.put(self.suffix["local_weight"](), sample_num) + self._weight = 1.0 + elif aggregate_type == 'weighted_mean': + if sample_num <= 0 or sample_num is None: + raise ValueError("sample_num should be int greater than 0") + ctx.arbiter.put(self.suffix["local_weight"](), sample_num) + self._weight = ctx.arbiter.get(self.suffix["computed_weight"]()) + + logger.info("aggregate weight is {}".format(self._weight)) + + self.model_aggregator = sa_client(prefix=self.aggregator_name+'_model', is_mock=is_mock) + self.model_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + self.loss_aggregator = sa_client(prefix=self.aggregator_name+'_loss', is_mock=is_mock) + self.loss_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + + + def _convert_type(self, data, dtype='float32'): + + if isinstance(data, t.Tensor): + if dtype == 'float32': + data = data.float() + elif dtype == 'float64': + data = data.double() + else: + raise ValueError("Invalid dtype. Choose either 'float32' or 'float64'") + + numpy_array = data.detach().cpu().numpy() + + elif isinstance(data, np.ndarray): + if dtype == 'float32': + numpy_array = data.astype(np.float32) + elif dtype == 'float64': + numpy_array = data.astype(np.float64) + else: + raise ValueError("Invalid dtype. Choose either 'float32' or 'float64'") + else: + raise ValueError("Invalid data type. 
Only numpy ndarray and PyTorch tensor are supported.") + + return numpy_array + + def _process_model(self, model): + + to_agg = None + if isinstance(model, np.ndarray) or isinstance(model, t.Tensor): + to_agg = self._convert_type(model, self.float_p) + return [to_agg] + + if isinstance(model, t.nn.Module): + parameters = list(model.parameters()) + if self.require_grad: + agg_list = [self._convert_type(p.cpu().detach().numpy(), self.float_p) for p in parameters if p.requires_grad] + else: + agg_list = [self._convert_type(p.cpu().detach().numpy(), self.float_p) for p in parameters] + + elif isinstance(model, list): + to_agg = [] + for p in model: + to_agg.append(self._convert_type(p, self.float_p)) + agg_list = to_agg + + return agg_list + + def _recover_model(self, model, agg_model): + + if isinstance(model, np.ndarray) or isinstance(model, t.Tensor): + return agg_model + elif isinstance(model, t.nn.Module): + if self.require_grad: + for agg_p, p in zip(agg_model, [p for p in model.parameters() if p.requires_grad]): + p.data.copy_(t.Tensor(agg_p)) + else: + for agg_p, p in zip(agg_model, model.parameters()): + p.data.copy_(t.Tensor(agg_p)) + return model + else: + return agg_model + + """ + User API + """ + + def model_aggregation(self, ctx, model): + + to_send = self._process_model(model) + agg_model = self.model_aggregator.secure_aggregate(ctx, to_send, self._weight) + return self._recover_model(model, agg_model) + + def loss_aggregation(self, ctx, loss): + if isinstance(loss, t.Tensor): + loss = loss.detach.cpu().numpy() + else: + loss = np.array(loss) + loss = [loss] + agg_loss = self.loss_aggregator.secure_aggregate(ctx, loss, self._weight) + return agg_loss + + +class BaseAggregatorServer(Aggregator): + + + def __init__(self, ctx: Context, aggregator_name: str = None, is_mock=True) -> None: + + super().__init__(ctx, aggregator_name) + + weight_list = self._collect(ctx, self.suffix["local_weight"]()) + weight_sum = sum(weight_list) + ret_weight = [] + for w in weight_list: + ret_weight.append(w / weight_sum) + + ret_suffix = self.suffix["computed_weight"]() + for idx, w in enumerate(ret_weight): + self._broadcast(ctx, w, ret_suffix, idx) + + self.aggregator_name = 'default' if aggregator_name is None else aggregator_name + self.model_aggregator = sa_server(prefix=self.aggregator_name+'_model', is_mock=is_mock, ranks=[ctx.guest.rank, *ctx.hosts.ranks]) + self.loss_aggregator = sa_server(prefix=self.aggregator_name+'_loss', is_mock=is_mock, ranks=[ctx.guest.rank, *ctx.hosts.ranks]) + + def _check_party_id(self, party_id): + # party idx >= -1, int + if not isinstance(party_id, int): + raise ValueError("party_id should be int") + if party_id < -1: + raise ValueError("party_id should be greater than -1") + + def _collect(self, ctx, suffix): + guest_item = [ctx.guest.get(suffix)] + host_item = ctx.hosts.get(suffix) + combine_list = guest_item + host_item + return combine_list + + def _broadcast(self, ctx, data, suffix, party_idx=-1): + self._check_party_id(party_idx) + if party_idx == -1: + ctx.guest.put(suffix, data) + ctx.hosts.put(suffix, data) + elif party_idx == 0: + ctx.guest.put(suffix, data) + else: + ctx.hosts[party_idx - 1].put(suffix, data) + + """ + User API + """ + + def model_aggregation(self, ctx, ranks=None): + self.model_aggregator.secure_aggregate(ctx, ranks=ranks) + + def loss_aggregation(self, ctx, ranks=None): + self.loss_aggregator.secure_aggregate(ctx, ranks=ranks) + \ No newline at end of file diff --git a/python/fate/ml/aggregator/plaintext_aggregator.py 
b/python/fate/ml/aggregator/plaintext_aggregator.py index b8d09cfc04..81fc85d5d2 100644 --- a/python/fate/ml/aggregator/plaintext_aggregator.py +++ b/python/fate/ml/aggregator/plaintext_aggregator.py @@ -1,162 +1,14 @@ -import torch as t -import numpy as np from fate.arch import Context -from typing import Union -from .base import Aggregator -import logging -from fate.arch.protocol._dh import SecureAggregatorClient as sa_client -from fate.arch.protocol._dh import SecureAggregatorServer as sa_server +from fate.ml.aggregator.base import BaseAggregatorClient, BaseAggregatorServer -logger = logging.getLogger(__name__) - - -AGGREGATE_TYPE = ['mean', 'sum', 'weighted_mean'] - - -class PlainTextAggregatorClient(Aggregator): - - """ - PlainTextAggregatorClient is used to aggregate plain text data - """ +class PlainTextAggregatorClient(BaseAggregatorClient): def __init__(self, ctx: Context, aggregator_name: str = None, aggregate_type='mean', sample_num=1) -> None: - - super().__init__(ctx, aggregator_name) - self.ctx = ctx - self._weight = 1.0 - self.aggregator_name = 'default' if aggregator_name is None else aggregator_name - - if sample_num <= 0 and not isinstance(sample_num, int): - raise ValueError("sample_num should be int greater than 0") - - logger.info('computing weights') - if aggregate_type not in AGGREGATE_TYPE: - raise ValueError("aggregate_type should be one of {}".format(AGGREGATE_TYPE)) - elif aggregate_type == 'mean': - self.ctx.arbiter.put(self.suffix["local_weight"](), 1.0) - self._weight = self.ctx.arbiter.get(self.suffix["computed_weight"]()) - elif aggregate_type == 'sum': - self.ctx.arbiter.put(self.suffix["local_weight"](), sample_num) - self._weight = 1.0 - elif aggregate_type == 'weighted_mean': - if sample_num <= 0 or sample_num is None: - raise ValueError("sample_num should be int greater than 0") - self.ctx.arbiter.put(self.suffix["local_weight"](), sample_num) - self._weight = self.ctx.arbiter.get(self.suffix["computed_weight"]()) - - logger.info("aggregate weight is {}".format(self._weight)) - - self.model_aggregator = sa_client(prefix=self.aggregator_name+'_model', is_mock=True) - self.model_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) - self.loss_aggregator = sa_client(prefix=self.aggregator_name+'_loss', is_mock=True) - self.loss_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) - - def _process_model(self, model): - - to_agg = None - if isinstance(model, np.ndarray) or isinstance(model, t.Tensor): - to_agg = model * self._weight - return [to_agg] - - if isinstance(model, t.nn.Module): - parameters = list(model.parameters()) - agg_list = [p.cpu().detach().numpy() for p in parameters if p.requires_grad] - - elif isinstance(model, list): - for p in model: - assert isinstance( - p, np.ndarray), 'expecting List[np.ndarray], but got {}'.format(p) - agg_list = model - - return agg_list + super().__init__(ctx, aggregator_name, aggregate_type, sample_num, is_mock=True) - def _recover_model(self, model, agg_model): - - if isinstance(model, np.ndarray) or isinstance(model, t.Tensor): - return agg_model - elif isinstance(model, t.nn.Module): - for agg_p, p in zip(agg_model, [p for p in model.parameters() if p.requires_grad]): - p.data.copy_(t.Tensor(agg_p)) - return model - else: - return agg_model - """ - User API - """ - - def model_aggregation(self, model): - - to_send = self._process_model(model) - print('model is ', to_send) - agg_model = self.model_aggregator.secure_aggregate(self.ctx, to_send, self._weight) - return 
self._recover_model(model, agg_model) - - def loss_aggregation(self, loss): - if isinstance(loss, t.Tensor): - loss = loss.detach.cpu().numpy() - else: - loss = np.array(loss) - loss = [loss] - agg_loss = self.loss_aggregator.secure_aggregate(self.ctx, loss, self._weight) - return agg_loss - - -class PlainTextAggregatorServer(Aggregator): - - """ - PlainTextAggregatorServer is used to aggregate plain text data - """ +class PlainTextAggregatorServer(BaseAggregatorServer): def __init__(self, ctx: Context, aggregator_name: str = None) -> None: - - super().__init__(ctx, aggregator_name) - - weight_list = self._collect(self.suffix["local_weight"]()) - weight_sum = sum(weight_list) - ret_weight = [] - for w in weight_list: - ret_weight.append(w / weight_sum) - - ret_suffix = self.suffix["computed_weight"]() - for idx, w in enumerate(ret_weight): - self._broadcast(w, ret_suffix, idx) - - self.aggregator_name = 'default' if aggregator_name is None else aggregator_name - self.model_aggregator = sa_server(prefix=self.aggregator_name+'_model', is_mock=True, ranks=[ctx.guest.rank, *ctx.hosts.ranks]) - self.loss_aggregator = sa_server(prefix=self.aggregator_name+'_loss', is_mock=True, ranks=[ctx.guest.rank, *ctx.hosts.ranks]) - - def _check_party_id(self, party_id): - # party idx >= -1, int - if not isinstance(party_id, int): - raise ValueError("party_id should be int") - if party_id < -1: - raise ValueError("party_id should be greater than -1") - - def _collect(self, suffix): - guest_item = [self.ctx.guest.get(suffix)] - host_item = self.ctx.hosts.get(suffix) - combine_list = guest_item + host_item - return combine_list - - def _broadcast(self, data, suffix, party_idx=-1): - self._check_party_id(party_idx) - if party_idx == -1: - self.ctx.guest.put(suffix, data) - self.ctx.hosts.put(suffix, data) - elif party_idx == 0: - self.ctx.guest.put(suffix, data) - else: - self.ctx.hosts[party_idx - 1].put(suffix, data) - - """ - User API - """ - - def model_aggregation(self, ranks=None): - self.model_aggregator.secure_aggregate(self.ctx, ranks=ranks) - - def loss_aggregation(self, ranks=None): - self.loss_aggregator.secure_aggregate(self.ctx, ranks=ranks) - \ No newline at end of file + super().__init__(ctx, aggregator_name, is_mock=True) \ No newline at end of file diff --git a/python/fate/ml/aggregator/secure_aggregator.py b/python/fate/ml/aggregator/secure_aggregator.py index 08d38db459..0ec0ad37dd 100644 --- a/python/fate/ml/aggregator/secure_aggregator.py +++ b/python/fate/ml/aggregator/secure_aggregator.py @@ -1,22 +1,14 @@ from fate.arch import Context +from fate.ml.aggregator.base import BaseAggregatorClient, BaseAggregatorServer -class SecureAggregatorClient(object): +class SecureAggregatorClient(BaseAggregatorClient): - def __init__(self, ctx: Context, aggregate_type='weighted_mean', aggregate_weight=1.0) -> None: - self.aggregate_type = aggregate_type - self.aggregate_weight = aggregate_weight - self.ctx = ctx + def __init__(self, ctx: Context, aggregator_name: str = None, aggregate_type='mean', sample_num=1) -> None: + super().__init__(ctx, aggregator_name, aggregate_type, sample_num, is_mock=False) - def aggregate(self): - pass +class SecureAggregatorServer(BaseAggregatorServer): -class SecureAggregatorServer(object): - - def __init__(self, ctx: Context) -> None: - pass - - def aggregate(self): - pass - + def __init__(self, ctx: Context, aggregator_name: str = None) -> None: + super().__init__(ctx, aggregator_name, is_mock=False) diff --git a/python/fate/ml/aggregator/test/test_aggregator.py 
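# Illustrative sketch (not part of the patch): after this refactor the
# plaintext and secure aggregators share one implementation and differ only
# in the is_mock flag forwarded to the secure-aggregation protocol client
# (is_mock=True sends values in the clear, is_mock=False masks them first).
# A toy version of the same class collapse, with hypothetical names:

class _ToyBaseClient:
    def __init__(self, is_mock):
        self.is_mock = is_mock

    def describe(self):
        return 'plaintext' if self.is_mock else 'masked'

class ToyPlainTextClient(_ToyBaseClient):
    def __init__(self):
        super().__init__(is_mock=True)   # same code path, masking disabled

class ToySecureClient(_ToyBaseClient):
    def __init__(self):
        super().__init__(is_mock=False)  # same code path, DH masking enabled

assert ToyPlainTextClient().describe() == 'plaintext'
assert ToySecureClient().describe() == 'masked'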
b/python/fate/ml/aggregator/test/test_aggregator.py
index 177e8e210c..ba74c743c6 100644
--- a/python/fate/ml/aggregator/test/test_aggregator.py
+++ b/python/fate/ml/aggregator/test/test_aggregator.py
@@ -44,8 +44,9 @@ def create_ctx(local):
         t.nn.Sigmoid()
     )
 
-    for i in range(epoch):
-        client.model_aggregation(model)
+    for i, iter_ctx in ctx.on_iterations.ctxs_range(epoch):
+        client.model_aggregation(iter_ctx, model)
+
 
 elif sys.argv[1] == "host":
     from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient
@@ -58,8 +59,8 @@ def create_ctx(local):
         t.nn.Sigmoid()
     )
 
-    for i in range(epoch):
-        client.model_aggregation(model)
+    for i, iter_ctx in ctx.on_iterations.ctxs_range(epoch):
+        client.model_aggregation(iter_ctx, model)
 
 else:
 
@@ -67,6 +68,6 @@ def create_ctx(local):
     ctx = create_ctx(arbiter)
     server = PlainTextAggregatorServer(ctx)
 
-    for i in range(epoch):
-        server.model_aggregation()
+    for i, iter_ctx in ctx.on_iterations.ctxs_range(epoch):
+        server.model_aggregation(iter_ctx)
 
diff --git a/python/fate/ml/glm/homo/lr/test/test_fed_lr.py b/python/fate/ml/glm/homo/lr/test/test_fed_lr.py
new file mode 100644
index 0000000000..f4ed496c93
--- /dev/null
+++ b/python/fate/ml/glm/homo/lr/test/test_fed_lr.py
@@ -0,0 +1,71 @@
+from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, FedArguments, TrainingArguments, FedAVGServer
+import torch as t
+import pandas as pd
+import sys
+from fate.arch.dataframe import PandasReader
+from fate.ml.glm.homo.lr.client import HomoLRClient
+from fate.ml.glm.homo.lr.server import HomoLRServer
+
+
+arbiter = ("arbiter", 10000)
+guest = ("guest", 10000)
+host = ("host", 9999)
+name = "fed"
+
+
+def create_ctx(local):
+    from fate.arch import Context
+    from fate.arch.computing.standalone import CSession
+    from fate.arch.federation.standalone import StandaloneFederation
+    import logging
+    logger = logging.getLogger()
+    logger.setLevel(logging.DEBUG)
+
+    console_handler = logging.StreamHandler()
+    console_handler.setLevel(logging.DEBUG)
+
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    console_handler.setFormatter(formatter)
+
+    logger.addHandler(console_handler)
+    computing = CSession()
+    return Context(computing=computing,
+                   federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))
+
+
+
+if __name__ == "__main__":
+
+    if sys.argv[1] == "guest":
+
+        ctx = create_ctx(guest)
+        df = pd.read_csv(
+            '../../../../../../../examples/data/breast_homo_guest.csv')
+        df['sample_id'] = [i for i in range(len(df))]
+
+        reader = PandasReader(
+            sample_id_name='sample_id',
+            match_id_name="id",
+            label_name="y",
+            dtype="object")
+
+        data = reader.to_frame(ctx, df)
+
+
+    elif sys.argv[1] == "host":
+
+        ctx = create_ctx(host)
+        df = pd.read_csv(
+            '../../../../../../../examples/data/breast_homo_host.csv')
+        df['sample_id'] = [i for i in range(len(df))]
+
+        reader = PandasReader(
+            sample_id_name='sample_id',
+            match_id_name="id",
+            label_name="y",
+            dtype="object")
+
+        data = reader.to_frame(ctx, df)
+
+    else:
+        ctx = create_ctx(arbiter)
\ No newline at end of file
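# Illustrative sketch (not part of the patch): model_aggregation now takes an
# iteration-scoped context, so every aggregation round gets its own federation
# namespace; reusing one namespace across rounds would collide on transfer
# keys. Toy model of the indexed-namespace idea (names are hypothetical, not
# the fate.arch API):

class ToyIterCtx:
    def __init__(self, namespace):
        self.namespace = namespace

def toy_ctxs_range(base, n):
    # mirrors the shape of ctx.on_iterations.ctxs_range(epoch)
    for i in range(n):
        yield i, ToyIterCtx('{}.iterations.{}'.format(base, i))

namespaces = [c.namespace for _, c in toy_ctxs_range('fed', 3)]
assert len(set(namespaces)) == 3  # one distinct key space per round

diff --git a/python/fate/ml/glm/homo/lr/test/local_test.py b/python/fate/ml/glm/homo/lr/test/test_local.py
similarity index 95%
rename from python/fate/ml/glm/homo/lr/test/local_test.py
rename to python/fate/ml/glm/homo/lr/test/test_local.py
index 2572f129ed..05e6608fda 100644
--- a/python/fate/ml/glm/homo/lr/test/local_test.py
+++ b/python/fate/ml/glm/homo/lr/test/test_local.py
@@ -24,7 +24,7 @@
         computing, "fed", ("guest", 10000), [("host",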
9999)]), ) df = pd.read_csv( - '/home/cwj/FATE/FATE-2.0-pure/FATE/examples/data/breast_homo_guest.csv') + '../../../../../../../examples/data/breast_homo_guest.csv') df['sample_id'] = [i for i in range(len(df))] reader = PandasReader( @@ -43,6 +43,7 @@ ds = TableDataset(return_dict=True, to_tensor=True) ds.load(data) + client = HomoLRClient( 50, 800, optimizer_param={ 'method': 'adam', 'penalty': 'l1', 'aplha': 0.1, 'optimizer_para': { diff --git a/python/fate/ml/nn/algo/homo/fedavg.py b/python/fate/ml/nn/algo/homo/fedavg.py index debc91fcb4..d4c6aafb7f 100644 --- a/python/fate/ml/nn/algo/homo/fedavg.py +++ b/python/fate/ml/nn/algo/homo/fedavg.py @@ -1,11 +1,10 @@ from transformers.training_args import TrainingArguments -from fate.ml.aggregator.base import Aggregator from fate.ml.nn.trainer.trainer_base import FedTrainerClient, FedTrainerServer, TrainingArguments -from fate.ml.nn.trainer.trainer_base import FedArguments, time_decorator, TrainingArguments +from fate.ml.nn.trainer.trainer_base import FedArguments, TrainingArguments from dataclasses import field from dataclasses import dataclass, field from dataclasses import dataclass -from typing import List, Optional, Tuple, Callable +from typing import List, Optional, Tuple, Callable, Union from fate.arch import Context from torch.optim import Optimizer from torch.utils.data import Dataset @@ -14,10 +13,15 @@ from torch.nn import Module from torch import nn from torch.utils.data import DataLoader -from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient, PlainTextAggregatorServer +from fate.ml.aggregator import PlainTextAggregatorClient, SecureAggregatorClient +from fate.ml.aggregator import PlainTextAggregatorServer, SecureAggregatorServer from transformers import TrainerState, TrainerControl, PreTrainedTokenizer +from fate.ml.aggregator import AggregatorType, aggregator_map +import logging +logger = logging.getLogger(__name__) + @dataclass class FedAVGArguments(FedArguments): @@ -71,20 +75,24 @@ def __init__(self, compute_metrics=compute_metrics, local_mode=local_mode) - def init_aggregator(self): + def init_aggregator(self, ctx: Context, fed_args: FedArguments): + + aggregate_type = 'weighted_mean' + aggregator_name = 'fedavg' + aggregator = fed_args.aggregator + assert aggregator in {item.value for item in AggregatorType}, f"aggregator should be one of {{item.value for item in AggregatorType}}, but got {aggregator}" + client_class = aggregator_map[aggregator][0] + logger.info(f"Using {aggregator} aggregator") sample_num = len(self.train_dataset) - aggregator = PlainTextAggregatorClient( - self.ctx, - aggregator_name='fed_avg', - aggregate_type='weighted_mean', - sample_num=sample_num) + ctx.arbiter.put('agg_type', aggregator) + aggregator = client_class(ctx, aggregate_type=aggregate_type, aggregator_name=aggregator_name, sample_num=sample_num) + return aggregator - @time_decorator('FedAVG') def on_federation( self, ctx: Context, - aggregator: PlainTextAggregatorClient, + aggregator: Union[PlainTextAggregatorClient, SecureAggregatorClient], fed_args: FedArguments, args: TrainingArguments, model: Optional[nn.Module] = None, @@ -95,33 +103,38 @@ def on_federation( state: Optional[TrainerState] = None, **kwargs): - aggregator.model_aggregation(model) + aggregator.model_aggregation(ctx, model) class FedAVGServer(FedTrainerServer): def __init__(self, ctx: Context, - training_args: TrainingArguments = None, - fed_args: FedArguments = None, - parameter_alignment: bool = True, local_mode: bool = False ) -> None: - 
super().__init__(ctx, training_args, fed_args, parameter_alignment, local_mode) + super().__init__(ctx, local_mode) - def init_aggregator(self): - aggregator = PlainTextAggregatorServer( - self.ctx, aggregator_name='fed_avg') + def init_aggregator(self, ctx): + + aggregator = [ctx.guest.get('agg_type')] + aggregator.extend(ctx.hosts.get('agg_type')) + aggregator = set(aggregator) + if len(aggregator) > 1: + raise ValueError('Aggregator type should be the same between clients, but got {}'.format(aggregator)) + aggregator = aggregator.pop() + aggregator_name = 'fedavg' + aggregator_server = aggregator_map[aggregator][1] + logger.info(f"Using {aggregator} aggregator") + aggregator = aggregator_server(ctx, aggregator_name=aggregator_name) return aggregator def on_federation( self, ctx: Context, - aggregator: Aggregator, - fed_args: FedArguments, - args: TrainingArguments): - aggregator.model_aggregation() + aggregator: Union[SecureAggregatorServer, PlainTextAggregatorServer]): + + aggregator.model_aggregation(ctx) class FedAVG(object): diff --git a/python/fate/ml/nn/dataset/table.py b/python/fate/ml/nn/dataset/table.py index ec6cc6cb9e..3eff8aac78 100644 --- a/python/fate/ml/nn/dataset/table.py +++ b/python/fate/ml/nn/dataset/table.py @@ -156,9 +156,9 @@ def load(self, data_or_path): for i in label_col_candidates: if i in self.origin_table: label = i + logger.info('use "{}" as label column'.format(label)) break if label is None: - self.with_label = False logger.info( 'found no "y"/"label"/"target" in input table, no label will be set') else: @@ -166,9 +166,8 @@ def load(self, data_or_path): raise ValueError( "label column {} not found in input table".format(label)) - if self.label is not None: - self.label = self.origin_table[[label]].values - self.origin_table = self.origin_table.drop(columns=[label]) + self.label = self.origin_table[[label]].values + self.origin_table = self.origin_table.drop(columns=[label]) self.features = self.origin_table.values elif isinstance(data_or_path, DataFrame): diff --git a/python/fate/ml/nn/trainer/test/test_trainer.py b/python/fate/ml/nn/trainer/test/test_trainer.py new file mode 100644 index 0000000000..01f9ee627a --- /dev/null +++ b/python/fate/ml/nn/trainer/test/test_trainer.py @@ -0,0 +1,63 @@ +from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, FedArguments, TrainingArguments, FedAVGServer +import torch as t +import pandas as pd +from fate.ml.nn.dataset.table import TableDataset +import sys + + +arbiter = ("arbiter", 10000) +guest = ("guest", 10000) +host = ("host", 9999) +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context(computing=computing, + federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])) + + + +if __name__ == "__main__": + + epoch = 10 + model = t.nn.Sequential( + t.nn.Linear(30, 1), + t.nn.Sigmoid() + ) + + ds = TableDataset(return_dict=False, to_tensor=True) + ds.load('../../../../../../examples/data/breast_homo_guest.csv') + + if sys.argv[1] == "guest": + ctx = create_ctx(guest) + fed_args = 
FedArguments(aggregate_strategy='epochs', aggregate_freq=1, aggregator='secure_aggregate')
+        args = TrainingArguments(num_train_epochs=5, per_device_train_batch_size=16, logging_strategy='steps', logging_steps=5)
+        trainer = FedAVGCLient(ctx=ctx, model=model, fed_args=fed_args, training_args=args, loss_fn=t.nn.BCELoss(), optimizer=t.optim.SGD(model.parameters(), lr=0.01), train_set=ds)
+        trainer.train()
+
+    elif sys.argv[1] == "host":
+        ctx = create_ctx(host)
+        fed_args = FedArguments(aggregate_strategy='epochs', aggregate_freq=1, aggregator='secure_aggregate')
+        args = TrainingArguments(num_train_epochs=5, per_device_train_batch_size=16)
+        trainer = FedAVGCLient(ctx=ctx, model=model, fed_args=fed_args, training_args=args, loss_fn=t.nn.BCELoss(), optimizer=t.optim.SGD(model.parameters(), lr=0.01), train_set=ds)
+        trainer.train()
+
+    else:
+        ctx = create_ctx(arbiter)
+        trainer = FedAVGServer(ctx)
+        trainer.train()
\ No newline at end of file
diff --git a/python/fate/ml/nn/trainer/trainer_base.py b/python/fate/ml/nn/trainer/trainer_base.py
index 1ef4d2d94a..2117ab4285 100644
--- a/python/fate/ml/nn/trainer/trainer_base.py
+++ b/python/fate/ml/nn/trainer/trainer_base.py
@@ -23,9 +23,9 @@
 from transformers import logging as transformers_logging
 from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
 from typing import Optional
-import time
 from dataclasses import dataclass, field, fields
 from transformers.trainer_callback import PrinterCallback
+from fate.ml.aggregator import AggregatorType
 
 
 # Reset the logger to redirect logs output
@@ -34,17 +34,6 @@
 logger = logging.getLogger(__name__)
 
 
-def time_decorator(descr=""):
-    def decorator(func):
-        def wrapper(*args, **kwargs):
-            start_time = time.time()
-            result = func(*args, **kwargs)
-            end_time = time.time()
-            logger.info(f"{descr} takes {end_time - start_time:.2f} seconds.")
-            return result
-        return wrapper
-    return decorator
-
 
 def get_ith_checkpoint(directory, i):
     # List all files in the directory
@@ -78,8 +67,9 @@ def get_ith_checkpoint(directory, i):
 
 
 class AggregateStrategy(Enum):
-    EPOCH = "epoch"
-    STEP = "step"
+    EPOCH = "epochs"
+    STEP = "steps"
+
 
 
 @dataclass
@@ -90,6 +80,7 @@ class FedArguments(object):
     aggregate_strategy: AggregateStrategy = field(
         default=AggregateStrategy.EPOCH.value)
     aggregate_freq: int = field(default=1)
+    aggregator: str = field(default=AggregatorType.SECURE_AGGREGATE.value)
 
     def to_dict(self):
         """
@@ -287,7 +278,7 @@
         **kwargs):
         pass
 
-    def init_aggregator(self):
+    def init_aggregator(self, fed_arg: FedArguments):
         raise NotImplementedError(
             'init_aggregator() must be implemented in subclass, init aggregator here')
 
@@ -333,7 +324,7 @@ def compute_max_aggregation(
     elif fed_args.aggregate_strategy == AggregateStrategy.STEP.value:
         max_aggregation = int((max_steps - steps_trained) / aggregate_freq)
     else:
-        raise ValueError('aggregate_strategy must be either "epoch" or "step"')
+        raise ValueError('aggregate_strategy must be either "epochs" or "steps"')
 
     return max_aggregation, aggregate_freq
@@ -379,17 +370,17 @@ def should_aggregate(self, state: TrainerState) -> bool:
         if strategy == AggregateStrategy.EPOCH.value:
             if cur_epoch > self.epochs_trained and (
                     cur_epoch - self.epochs_trained) % self.aggregate_freq == 0:
-                self.aggregation_count += 1
-                self.report()
                 return True
         elif strategy == AggregateStrategy.STEP.value:
             if cur_step > self.steps_trained and (
                     cur_step - self.steps_trained) % self.aggregate_freq == 0:
-                self.aggregation_count += 1
-                self.report()
                 return True
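# Illustrative sketch (not part of the patch): with the new FedArguments
# 'aggregator' field used in the test above, each client sends its chosen
# aggregator type to the server, which only proceeds when all clients agree.
# The check below is a plain-Python rendering of the validation performed in
# FedAVGServer.init_aggregator.

def validate_agg_types(guest_type, host_types):
    reported = {guest_type, *host_types}
    if len(reported) > 1:
        raise ValueError(
            'Aggregator type should be the same between clients, '
            'but got {}'.format(reported))
    return reported.pop()

assert validate_agg_types('secure_aggregate', ['secure_aggregate']) == 'secure_aggregate'
# mixing 'plaintext' with 'secure_aggregate' would raise ValueError here

         return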
False + + def inc_aggregation_count(self): + self.aggregation_count += 1 + self.report() class FedParameterAlignCallback(TrainerCallback): @@ -568,7 +559,8 @@ def __init__(self, ctx: Context, wrapped_trainer: 'StdFedTrainerMixin'): self.wrapped_trainer = wrapped_trainer self.fed_arg = self.wrapped_trainer._fed_args - def _call_wrapped(self, event_name: str, **kwargs): + def _call_wrapped(self, ctx, aggregator, fed_arg, event_name: str, **kwargs): + event = getattr(self.wrapped_trainer, event_name) kwargs['scheduler'] = kwargs.pop('lr_scheduler', None) @@ -577,13 +569,13 @@ def _call_wrapped(self, event_name: str, **kwargs): dataloaders = tuple(filter(None, (train_dataloader, eval_dataloader))) kwargs['dataloader'] = dataloaders return event( - self.ctx, - self.wrapped_trainer.aggregator, - self.fed_arg, + ctx, + aggregator, + fed_arg, **kwargs) -class FedCallbackWrapper(CallbackWrapper): +class WrappedFedCallback(CallbackWrapper): def __init__(self, ctx: Context, wrapped_trainer: 'StdFedTrainerMixin'): super().__init__(ctx, wrapped_trainer) @@ -601,7 +593,7 @@ def on_train_begin( logger.info( 'local mode, skip federation aggregator initialization, aggregator will be None') else: - self.wrapped_trainer.aggregator = self.wrapped_trainer.init_aggregator() + self.wrapped_trainer.aggregator = self.wrapped_trainer.init_aggregator(self.ctx, self.fed_arg) def on_epoch_end( self, @@ -615,12 +607,19 @@ def on_epoch_end( if self.wrapped_trainer.aggregation_checker.should_aggregate( state): logger.info('aggregation on epoch end') - return self._call_wrapped( + agg_round = self.wrapped_trainer.aggregation_checker.aggregation_count + sub_ctx = self.ctx.sub_ctx('aggregation').indexed_ctx(agg_round) + ret = self._call_wrapped( + sub_ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_federation', args=args, state=state, control=control, **kwargs) + self.wrapped_trainer.aggregation_checker.inc_aggregation_count() + return ret def on_step_end( self, @@ -633,16 +632,25 @@ def on_step_end( if self.fed_arg.aggregate_strategy == AggregateStrategy.STEP.value: if self.wrapped_trainer.aggregation_checker.should_aggregate( state): + + logger.info('state is {}'.format(state)) logger.info('aggregation on step end') - return self._call_wrapped( + agg_round = self.wrapped_trainer.aggregation_checker.aggregation_count + sub_ctx = self.ctx.sub_ctx('aggregation').indexed_ctx(agg_round) + ret = self._call_wrapped( + sub_ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_federation', args=args, state=state, control=control, **kwargs) + self.wrapped_trainer.aggregation_checker.inc_aggregation_count() + return ret -class ShortcutCallbackWrapper(CallbackWrapper): +class WrappedShortcutCallback(CallbackWrapper): def __init__(self, ctx: Context, wrapped_trainer: 'StdFedTrainerMixin'): super().__init__(ctx, wrapped_trainer) @@ -654,6 +662,9 @@ def on_init_end( control: TrainerControl, **kwargs): return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_init_end', args=args, state=state, @@ -667,6 +678,9 @@ def on_train_begin( control: TrainerControl, **kwargs): return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_train_begin', args=args, state=state, @@ -680,6 +694,9 @@ def on_train_end( control: TrainerControl, **kwargs): return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_train_end', args=args, state=state, @@ -693,6 +710,9 @@ def on_epoch_begin( control: TrainerControl, **kwargs): return 
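# Illustrative sketch (not part of the patch): both WrappedFedCallback paths
# above derive a fresh indexed sub-context per aggregation round before calling
# on_federation, and only bump aggregation_count after the round completes.
# The counter-plus-namespace logic in miniature (hypothetical names):

class ToyChecker:
    def __init__(self):
        self.aggregation_count = 0

    def next_round_ctx(self, base_namespace):
        # mirrors ctx.sub_ctx('aggregation').indexed_ctx(agg_round)
        ns = '{}.aggregation.{}'.format(base_namespace, self.aggregation_count)
        self.aggregation_count += 1  # inc_aggregation_count() after federation
        return ns

checker = ToyChecker()
assert checker.next_round_ctx('fed') == 'fed.aggregation.0'
assert checker.next_round_ctx('fed') == 'fed.aggregation.1'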
self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_epoch_begin', args=args, state=state, @@ -706,6 +726,9 @@ def on_epoch_end( control: TrainerControl, **kwargs): return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_epoch_end', args=args, state=state, @@ -719,6 +742,9 @@ def on_step_begin( control: TrainerControl, **kwargs): return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_step_begin', args=args, state=state, @@ -732,6 +758,9 @@ def on_step_end( control: TrainerControl, **kwargs): return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, 'on_step_end', args=args, state=state, @@ -747,7 +776,7 @@ def on_step_end( """ -class StdFedTrainerMixin(ShortcutCallBackInterFace, FedCallbackInterface): +class StdFedTrainerMixin(FedCallbackInterface, ShortcutCallBackInterFace): def __init__(self, ctx: Context, @@ -763,8 +792,7 @@ def __init__(self, callbacks: Optional[List[TrainerCallback]] = [], use_hf_default_behavior: bool = False, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - local_mode: bool = False, - parameter_alignment=True + local_mode: bool = False ): assert isinstance( @@ -773,7 +801,6 @@ def __init__(self, self.ctx: Context = ctx self.local_mode = local_mode - self.parameter_alignment = parameter_alignment self._callbacks = callbacks self._args = training_args self._fed_args = fed_args @@ -822,20 +849,16 @@ def _add_fate_callback(self, callback_handler): new_callback_list.append(i) new_callback_list.append(FatePrinterCallback()) callback_handler.callbacks = new_callback_list - callback_handler.callbacks.append(FedCallbackWrapper(self.ctx, self)) - if self.parameter_alignment: - callback_handler.callbacks.append( - FedParameterAlignCallback( - self, - self.ctx, - fed_args=self._fed_args, - training_args=self._args, - is_server=False)) - else: - logger.warning( - 'Parameter alignment is disabled, this may cause fed-training failure') + callback_handler.callbacks.append(WrappedFedCallback(self.ctx, self)) callback_handler.callbacks.append( - ShortcutCallbackWrapper(self.ctx, self)) + FedParameterAlignCallback( + self, + self.ctx, + fed_args=self._fed_args, + training_args=self._args, + is_server=False)) + + callback_handler.callbacks.append(WrappedShortcutCallback(self.ctx, self)) def _remove_fed_callback(self, callback_class): self.callback_handler.callbacks = [ @@ -889,8 +912,7 @@ def __init__(self, callbacks: Optional[List[TrainerCallback]] = [], use_hf_default_behavior: bool = False, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - local_mode: bool = False, - parameter_alignment=True + local_mode: bool = False ): # in case you forget to set evaluation_strategy @@ -911,8 +933,8 @@ def __init__(self, callbacks=callbacks, use_hf_default_behavior=use_hf_default_behavior, compute_metrics=compute_metrics, - local_mode=local_mode, - parameter_alignment=parameter_alignment) + local_mode=local_mode + ) if data_collator is None: data_collator = _utils.collate.default_collate @@ -939,7 +961,7 @@ def __init__(self, self._add_fate_callback(self.callback_handler) - def init_aggregator(self): + def init_aggregator(self, ctx: Context, fed_arg: FedArguments): return None def compute_loss(self, model, inputs, **kwargs): @@ -990,17 +1012,11 @@ class FedTrainerServer(object): def __init__(self, ctx: Context, - training_args: TrainingArguments = None, - fed_args: FedArguments = None, - parameter_alignment: bool = 
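# Illustrative sketch (not part of the patch): every shortcut callback above
# now threads (ctx, aggregator, fed_args) explicitly into the wrapped event
# handler instead of reading them from trainer state. The dispatch pattern in
# miniature (hypothetical names):

class ToyWrapper:
    def __init__(self, ctx, aggregator, fed_args, handlers):
        self.ctx, self.aggregator, self.fed_args = ctx, aggregator, fed_args
        self.handlers = handlers  # e.g. {'on_epoch_end': callable}

    def call(self, event_name, **kwargs):
        handler = self.handlers[event_name]
        # same argument prefix for every event; extra kwargs pass through
        return handler(self.ctx, self.aggregator, self.fed_args, **kwargs)

w = ToyWrapper('ctx', 'agg', 'fed', {'on_epoch_end': lambda c, a, f, **kw: (c, a, f)})
assert w.call('on_epoch_end') == ('ctx', 'agg', 'fed')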
True, local_mode: bool = False ) -> None: self.ctx = ctx - self.parameter_alignment = parameter_alignment self.local_mode = local_mode - self._args = training_args - self._fed_args = fed_args self._max_steps = None self._parameter_check_callback = FedParameterAlignCallback( self, self.ctx, None, None, is_server=True) @@ -1019,39 +1035,31 @@ def set_fed_mode(self): self.local_mode = False logger.info('trainer set to federated mode') - def init_aggregator(self): + def init_aggregator(self, ctx: Context): return None def on_train_end( self, ctx: Context, - aggregator: Aggregator, - fed_args: FedArguments, - args: TrainingArguments): + aggregator: Aggregator): pass def on_train_begin( self, ctx: Context, - aggregator: Aggregator, - fed_args: FedArguments, - args: TrainingArguments): + aggregator: Aggregator): pass def on_init_end( self, ctx: Context, - aggregator: Aggregator, - fed_args: FedArguments, - args: TrainingArguments): + aggregator: Aggregator): pass def on_federation( self, ctx: Context, - aggregator: Aggregator, - fed_args: FedArguments, - args: TrainingArguments): + aggregator: Aggregator): pass def train(self): @@ -1061,42 +1069,32 @@ def train(self): 'Local model is set, skip initializing fed setting & aggregator') return - self.aggregator: Aggregator = self.init_aggregator() + self.aggregator: Aggregator = self.init_aggregator(self.ctx) logger.info('Initialized aggregator Done: {}'.format(self.aggregator)) - if self.parameter_alignment: - self._parameter_check_callback.on_train_begin( - None, None, None) # only get parameters from clients and align - parameters = self._parameter_check_callback.get_parameters() - self._max_aggregation = parameters['max_aggregation'] - logger.info('checked parameters are {}'.format(parameters)) - else: - logger.warn( - 'If you choose not to use parameter alignment, please make sure that the sever aggregation round matches clients\'') - self._max_aggregation, _ = compute_max_aggregation( - self._fed_args, self._args.num_train_epochs, self._args.max_steps, 0, 0) + self._parameter_check_callback.on_train_begin( + None, None, None) # only get parameters from clients and align + parameters = self._parameter_check_callback.get_parameters() + self._max_aggregation = parameters['max_aggregation'] + logger.info('checked parameters are {}'.format(parameters)) + self.on_init_end( self.ctx, - aggregator=self.aggregator, - args=self._args, - fed_args=self._fed_args) + aggregator=self.aggregator) self.on_train_begin( self.ctx, - aggregator=self.aggregator, - args=self._args, - fed_args=self._fed_args) + aggregator=self.aggregator) + + ctx = self.ctx for i in range(self._max_aggregation): + sub_ctx = ctx.sub_ctx('aggregation').indexed_ctx(i) self.on_federation( - self.ctx, - aggregator=self.aggregator, - args=self._args, - fed_args=self._fed_args) + sub_ctx, + aggregator=self.aggregator) + self.on_train_end( self.ctx, - aggregator=self.aggregator, - args=self._args, - fed_args=self._fed_args) + aggregator=self.aggregator) def predict(self): - # server does not need to predict pass From b551d5178f9ebba3db4135d4b287c69930664da1 Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Wed, 19 Jul 2023 16:14:00 +0800 Subject: [PATCH 61/61] osx develop Signed-off-by: forgive_dengkai --- java/osx/README.md | 8 +- java/osx/RELEASE.md | 7 + java/osx/bin/common.sh | 82 ++- java/osx/bin/service.sh | 14 +- .../main/java/com/osx/broker/Bootstrap.java | 165 ------ .../osx/broker/grpc/PushRequestDataWrap.java | 40 -- .../interceptor/RequestHandleInterceptor.java | 107 ---- 
.../osx/broker/message/MessageDecoder.java | 403 ------------- .../ptp/PtpClusterTopicApplyService.java | 53 -- .../com/osx/broker/ptp/PtpConsumeService.java | 121 ---- .../com/osx/broker/ptp/PtpProduceService.java | 129 ---- .../router/DefaultFateRouterServiceImpl.java | 260 --------- .../java/com/osx/broker/server/OsxServer.java | 193 ------ .../osx/broker/service/UnaryCallService.java | 74 --- .../com/osx/broker/util/TransferUtil.java | 295 ---------- .../src/main/resources/broker.properties | 23 - .../src/main/resources/route_table.json | 31 - .../java/com/osx/core/config/MetaInfo.java | 230 -------- .../java/com/osx/core/router/RouterInfo.java | 60 -- .../java/com/osx/core/utils/FlowLogUtil.java | 54 -- java/osx/deploy/auto-package.sh | 12 +- java/osx/osx-api/pom.xml | 59 ++ .../java/com/osx/api/constants}/Protocol.java | 5 +- .../java/com/osx/api/context/Context.java | 45 ++ .../java/com/osx/api/router/RouterInfo.java | 193 ++++++ .../osx/api/tech}/provider/TechProvider.java | 15 +- .../com/osx/api/translator/Translator.java | 16 + java/osx/{broker => osx-broker}/package.xml | 30 +- java/osx/{broker => osx-broker}/pom.xml | 42 +- .../main/java/com/osx/broker/Bootstrap.java | 91 +++ .../java/com/osx/broker/ServiceContainer.java | 128 ++-- .../com/osx/broker/buffer/BufferStatus.java | 0 .../com/osx/broker/buffer/ReadResult.java | 0 .../com/osx/broker/buffer/ReadStatus.java | 0 .../osx/broker/buffer/TransferBufferUtil.java | 0 .../com/osx/broker/buffer/WriteResult.java | 0 .../com/osx/broker/buffer/WriteStatus.java | 0 .../osx/broker/callback/CompleteCallback.java | 0 .../broker/callback/CreateUserCallback.java | 37 ++ .../osx/broker/callback/DestoryCallback.java | 0 .../osx/broker/callback/ErrorCallback.java | 0 .../callback/MockDesGrpcEventHandler.java | 78 +++ .../osx/broker/callback/MsgEventCallback.java | 10 + .../callback/MsgEventDispatchCallback.java | 29 + .../com/osx/broker/constants/Direction.java | 0 .../com/osx/broker/constants/MessageFlag.java | 6 +- .../osx/broker/consumer/ConsumerManager.java | 106 ++-- .../broker/consumer/EventDrivenConsumer.java | 52 ++ .../osx/broker/consumer/EventDriverRule.java | 7 + .../osx/broker/consumer/GrpcEventHandler.java | 150 +++++ .../broker/consumer/LocalQueueConsumer.java | 3 +- .../com/osx/broker/consumer/MessageEvent.java | 13 + .../osx/broker/consumer/RedirectConsumer.java | 3 +- .../consumer/SourceGrpcEventHandler.java | 65 +++ .../osx/broker/consumer/StreamConsumer.java | 0 .../osx/broker/consumer/UnaryConsumer.java | 51 +- .../com/osx/broker/eggroll/BaseProto.java | 0 .../broker/eggroll/ClusterManagerClient.java | 4 +- .../com/osx/broker/eggroll/CommandClient.java | 0 .../com/osx/broker/eggroll/CommandURI.java | 0 .../com/osx/broker/eggroll/ErEndpoint.java | 0 .../com/osx/broker/eggroll/ErFunctor.java | 0 .../java/com/osx/broker/eggroll/ErJob.java | 0 .../com/osx/broker/eggroll/ErPartition.java | 0 .../com/osx/broker/eggroll/ErProcessor.java | 0 .../osx/broker/eggroll/ErRollSiteHeader.java | 0 .../com/osx/broker/eggroll/ErSession.java | 0 .../com/osx/broker/eggroll/ErSessionMeta.java | 0 .../java/com/osx/broker/eggroll/ErStore.java | 0 .../osx/broker/eggroll/ErStoreLocator.java | 0 .../java/com/osx/broker/eggroll/ErTask.java | 0 .../broker/eggroll/EventDriverMsgManager.java | 69 +++ .../java/com/osx/broker/eggroll/IdUtils.java | 0 .../com/osx/broker/eggroll/MetaCommnads.java | 0 .../osx/broker/eggroll/PartitionerTypes.java | 0 .../osx/broker/eggroll/PushEventHandler.java | 291 +++++++++ .../eggroll/PutBatchSinkPushRespSO.java | 0 
.../osx/broker/eggroll/PutBatchSinkUtil.java | 14 - .../java/com/osx/broker/eggroll/RollPair.java | 0 .../osx/broker/eggroll/RollPairContext.java | 0 .../com/osx/broker/eggroll/SerdesTypes.java | 0 .../osx/broker/eggroll/SessionCommands.java | 0 .../osx/broker/eggroll/SessionConfKeys.java | 0 .../com/osx/broker/eggroll/SessionStatus.java | 0 .../broker/flow/ClusterMetricStatistics.java | 0 .../com/osx/broker/flow/ClusterRuleUtil.java | 0 .../grpc/ContextPrepareInterceptor.java | 0 .../osx/broker/grpc/ForwardPullRespSO.java | 2 +- .../osx/broker/grpc/ForwardPushRespSO.java | 2 +- .../com/osx/broker/grpc/PcpGrpcService.java | 0 .../com/osx/broker/grpc/ProxyGrpcService.java | 26 +- .../osx/broker/grpc/PullRequestDataWrap.java | 0 .../osx/broker/grpc/PushRequestDataWrap.java | 40 ++ .../grpc/QueuePushReqStreamObserver.java | 156 ++--- .../osx/broker/grpc/QueueStreamBuilder.java | 125 ++++ .../broker/grpc/ServiceExceptionHandler.java | 0 .../com/osx/broker/http/DispatchServlet.java | 47 +- .../com/osx/broker/http/HttpClientPool.java | 62 +- .../com/osx/broker/http/HttpsClientPool.java | 205 +++++++ .../com/osx/broker/http/PtpHttpResponse.java | 0 .../interceptor/PcpHandleInterceptor.java | 36 ++ .../interceptor/PushHandleInterceptor.java | 19 + .../broker/interceptor/RouterInterceptor.java | 29 +- .../TokenValidatorInterceptor.java | 28 + .../UnaryCallHandleInterceptor.java | 19 + .../message/AllocateMappedFileService.java | 0 .../broker/message/AppendMessageHandler.java | 0 .../broker/message/AppendMessageResult.java | 0 .../broker/message/AppendMessageStatus.java | 0 .../message/DefaultAppendMessageHandler.java | 83 +-- .../java/com/osx/broker/message/Message.java | 0 .../osx/broker/message/MessageDecoder.java | 410 +++++++++++++ .../com/osx/broker/message/MessageExt.java | 7 - .../broker/message/MessageExtBrokerInner.java | 11 +- .../broker/message/MessageStoreConfig.java | 0 .../osx/broker/message/MessageSysFlag.java | 0 .../com/osx/broker/message/MessageWraper.java | 0 .../message/SelectMappedBufferResult.java | 0 .../broker/metric/ClusterMetricLeapArray.java | 0 .../broker/ptp/AbstractPtpServiceAdaptor.java | 8 +- .../com/osx/broker/ptp/PtpAckService.java | 20 +- .../broker/ptp/PtpCancelTransferService.java | 9 +- .../ptp/PtpClusterTokenApplyService.java | 6 +- .../ptp/PtpClusterTopicApplyService.java | 60 ++ .../com/osx/broker/ptp/PtpConsumeService.java | 135 +++++ .../osx/broker/ptp/PtpForwardPushRespSO.java | 3 +- .../com/osx/broker/ptp/PtpProduceService.java | 208 +++++++ .../com/osx/broker/ptp/PtpPushService.java | 16 +- .../ptp/PtpQueryTransferQueueService.java | 4 +- .../osx/broker/ptp/PtpStreamTestService.java | 99 ++++ .../osx/broker/ptp/PtpUnaryCallService.java | 26 +- .../java/com/osx/broker/queue/Consumer.java | 4 +- .../osx/broker/queue/CreateQueueResult.java | 0 .../java/com/osx/broker/queue/MappedFile.java | 0 .../com/osx/broker/queue/MappedFileQueue.java | 0 .../com/osx/broker/queue/PutMessageLock.java | 0 .../broker/queue/PutMessageReentrantLock.java | 0 .../osx/broker/queue/PutMessageResult.java | 0 .../osx/broker/queue/PutMessageStatus.java | 0 .../osx/broker/queue/ReferenceResource.java | 0 .../com/osx/broker/queue/TransferQueue.java | 103 +++- .../broker/queue/TransferQueueApplyInfo.java | 0 .../broker/queue/TransferQueueManager.java | 301 ++++++---- .../queue/TransferQueueMonitorService.java | 0 .../router/DefaultFateRouterServiceImpl.java | 445 ++++++++++++++ .../osx/broker/router/FateRouterService.java | 7 +- .../broker/router/RemoteRouterDataSource.java | 0 
.../com/osx/broker/router/RouterMetric.java | 0 .../com/osx/broker/router/RouterRegister.java | 70 +++ .../com/osx/broker/router/RouterService.java | 8 + .../broker/security/MockTokenGenerator.java | 13 + .../osx/broker/security/TokenGenerator.java | 10 + .../security/TokenGeneratorRegister.java | 73 +++ .../osx/broker/security/TokenValidator.java | 8 + .../security/TokenValidatorRegister.java | 76 +++ .../java/com/osx/broker/server/OsxServer.java | 311 ++++++++++ .../com/osx/broker/service/PushService.java | 19 +- .../osx/broker/service/RegisterService.java | 0 .../com/osx/broker/service/RouteService.java | 28 + .../osx/broker/service/TokenApplyService.java | 13 +- .../osx/broker/service/UnaryCallService.java | 105 ++++ .../java/com/osx/broker/store/IndexQueue.java | 16 +- .../com/osx/broker/store/MessageStore.java | 0 .../osx/broker/token/DefaultTokenService.java | 5 +- .../java/com/osx/broker/util/ContextUtil.java | 14 +- .../java/com/osx/broker/util/DateUtils.java | 0 .../main/java/com/osx/broker/util/LibC.java | 0 .../com/osx/broker/util/MessageConst.java | 0 .../java/com/osx/broker/util/MessageId.java | 0 .../com/osx/broker/util/ResourceUtil.java | 3 +- .../java/com/osx/broker/util/TelnetUtil.java | 21 + .../java/com/osx/broker/util/TimeUtils.java | 0 .../broker/util/TransferExceptionUtil.java | 0 .../com/osx/broker/util/TransferUtil.java | 550 ++++++++++++++++++ .../java/com/osx/broker/util/UtilAll.java | 3 - .../broker/zk/AbstractZookeeperClient.java | 0 .../java/com/osx/broker/zk/ChildListener.java | 0 .../osx/broker/zk/CuratorZookeeperClient.java | 5 - .../java/com/osx/broker/zk/DataListener.java | 0 .../java/com/osx/broker/zk/EventType.java | 0 .../java/com/osx/broker/zk/StateListener.java | 0 .../main/java/com/osx/broker/zk/ZkConfig.java | 0 .../com/osx/broker/zk/ZookeeperClient.java | 0 .../osx/tech/provider/FateTechProvider.java | 187 +++--- .../tech/provider/TechProviderRegister.java | 34 +- .../main/resources/broker/broker.properties | 56 ++ .../src/main/resources/broker}/flowRule.json | 0 .../main/resources/broker/route_table.json | 44 ++ .../resources/components/provider.properties | 2 + .../resources/components/router.properties | 1 + .../components/translator.properties | 2 + .../src/main/resources/log4j2.xml | 0 .../cluster/ClusterClientEndpointTest.java | 0 .../com/osx/broker/mock/MockHttpServer.java | 27 + .../java/com/osx/broker/mock/MockServer.java | 57 +- .../com/osx/broker/test/grpc/EggrollTest.java | 0 .../com/osx/broker/test/grpc/Grpc_UC.java | 55 ++ .../com/osx/broker/test/grpc/NewFateTest.java | 79 ++- .../com/osx/broker/test/grpc/OldFateTest.java | 22 +- .../com/osx/broker/test/grpc/QueueTest.java | 124 ++-- .../osx/broker/test/grpc/SyncQueueTest.java | 0 .../osx/broker/test/grpc/UnaryCallTest.java | 6 + .../com/osx/broker/test/http/HttpTest.java | 24 +- .../broker/test/http/Http_PRODUCE_MSG.java | 92 +++ .../osx/broker/test/utils/JsonToMapCode.java | 28 + java/osx/{core => osx-core}/pom.xml | 8 +- .../main/java/com/osx/core/config/Config.java | 20 + .../com/osx/core/config/GrpcChannelInfo.java | 0 .../java/com/osx/core/config/MasterInfo.java | 0 .../java/com/osx/core/config/MetaInfo.java | 338 +++++++++++ .../com/osx/core/config/TransferMeta.java | 17 + .../com/osx/core/constant/ActionType.java | 1 + .../com/osx/core/constant/DeployMode.java | 0 .../main/java/com/osx/core/constant/Dict.java | 173 ++---- .../com/osx/core/constant/EncryptMethod.java | 0 .../osx/core/constant/NegotiationType.java | 0 .../com/osx/core/constant/PtpHttpHeader.java | 5 + 
.../main/java/com/osx/core/constant/Role.java | 5 + .../com/osx/core/constant/StatusCode.java | 6 + .../osx/core/constant/StreamLimitMode.java | 0 .../com/osx/core/constant/TransferStatus.java | 0 .../com/osx/core/context/FateContext.java} | 228 +++++--- .../core/datasource/AbstractDataSource.java | 0 .../datasource/AutoRefreshDataSource.java | 0 .../com/osx/core/datasource/Converter.java | 0 .../datasource/FileRefreshableDataSource.java | 0 .../core/datasource/NamedThreadFactory.java | 0 .../core/datasource/ReadableDataSource.java | 0 .../core/exceptions/AckIndexException.java | 0 .../osx/core/exceptions/BaseException.java | 0 .../core/exceptions/ConfigErrorException.java | 0 .../exceptions/ConsumeNoMessageException.java | 0 .../exceptions/ConsumerNotExistException.java | 0 .../exceptions/CreateTopicErrorException.java | 0 .../exceptions/CycleRouteInfoException.java | 9 + .../osx/core/exceptions/ErrorMessageUtil.java | 52 +- .../osx/core/exceptions/ExceptionInfo.java | 0 .../InvalidRedirectInfoException.java | 0 .../exceptions/InvalidRouteInfoException.java | 8 + .../core/exceptions/MappedFileException.java | 0 .../exceptions/MessageParseException.java | 0 .../exceptions/NoRouterInfoException.java | 0 .../core/exceptions/ParameterException.java | 0 .../core/exceptions/ProduceMsgExcption.java | 0 .../core/exceptions/PutMessageException.java | 0 .../core/exceptions/RemoteRpcException.java | 0 .../RouterInfoOperateException.java | 0 .../core/exceptions/SessionInitException.java | 12 + .../com/osx/core/exceptions/SysException.java | 0 .../TransferQueueAlreadyExistException.java | 0 .../TransferQueueInvalidStatusException.java | 0 .../TransferQueueNotExistException.java | 4 +- .../exceptions/UnSupportMethodException.java | 0 .../java/com/osx/core/flow/AbstractRule.java | 0 .../com/osx/core/flow/BucketLeapArray.java | 0 .../com/osx/core/flow/ClusterFlowChecker.java | 0 .../com/osx/core/flow/ClusterFlowConfig.java | 0 .../com/osx/core/flow/ClusterFlowEvent.java | 0 .../osx/core/flow/ClusterFlowRuleManager.java | 49 +- .../java/com/osx/core/flow/ClusterMetric.java | 0 .../osx/core/flow/ClusterMetricBucket.java | 0 .../osx/core/flow/ClusterMetricLeapArray.java | 0 .../core/flow/ClusterMetricStatistics.java | 0 .../osx/core/flow/ClusterRuleConstant.java | 0 .../com/osx/core/flow/ClusterRuleUtil.java | 0 .../core/flow/CurrentConcurrencyManager.java | 0 .../java/com/osx/core/flow/DebugSupport.java | 0 .../com/osx/core/flow/DynamicProperty.java | 0 .../com/osx/core/flow/FileMetricReport.java | 2 +- .../java/com/osx/core/flow/FlowCounter.java | 0 .../com/osx/core/flow/FlowCounterManager.java | 0 .../main/java/com/osx/core/flow/FlowRule.java | 0 .../main/java/com/osx/core/flow/Function.java | 0 .../osx/core/flow/GlobalRequestLimiter.java | 0 .../java/com/osx/core/flow/LeapArray.java | 0 .../java/com/osx/core/flow/LimitQueue.java | 0 .../main/java/com/osx/core/flow/Metric.java | 0 .../java/com/osx/core/flow/MetricBucket.java | 0 .../java/com/osx/core/flow/MetricEvent.java | 0 .../java/com/osx/core/flow/MetricNode.java | 0 .../java/com/osx/core/flow/MetricReport.java | 0 .../com/osx/core/flow/MetricSearcher.java | 0 .../java/com/osx/core/flow/MetricWriter.java | 0 .../java/com/osx/core/flow/MetricsReader.java | 0 .../osx/core/flow/NamespaceFlowProperty.java | 0 .../java/com/osx/core/flow/OccupySupport.java | 0 .../main/java/com/osx/core/flow/Property.java | 0 .../com/osx/core/flow/PropertyListener.java | 0 .../com/osx/core/flow/RequestLimiter.java | 0 .../src/main/java/com/osx/core/flow/Rule.java | 0 
.../java/com/osx/core/flow/RuleConstant.java | 0 .../main/java/com/osx/core/flow/TimeUtil.java | 0 .../java/com/osx/core/flow/TokenService.java | 0 .../com/osx/core/flow/UnaryLeapArray.java | 0 .../java/com/osx/core/flow/WindowWrap.java | 0 .../com/osx/core/frame/CountDownLatch.java | 0 .../osx/core/frame/GrpcConnectionFactory.java | 21 +- .../java/com/osx/core/frame/Lifecycle.java | 0 .../osx/core/frame/ServiceDataWrapper.java | 0 .../com/osx/core/frame/ServiceThread.java | 0 .../java/com/osx/core/jvm/JVMGCUtils.java | 0 .../java/com/osx/core/jvm/JVMMemoryUtils.java | 0 .../java/com/osx/core/jvm/JVMThreadUtils.java | 15 +- .../main/java/com/osx/core/jvm/JvmInfo.java | 0 .../java/com/osx/core/jvm/JvmInfoCounter.java | 12 +- .../com/osx/core/jvm/JvmInfoLeapArray.java | 0 .../com/osx/core/provider/TechProvider.java | 53 ++ .../java/com/osx/core/ptp/SourceMethod.java | 6 + .../java/com/osx/core/ptp/TargetMethod.java | 3 +- .../core/queue/ClusterTransferQueueInfo.java | 0 .../com/osx/core/queue/TranferQueueInfo.java | 0 .../core/service/AbstractServiceAdaptor.java | 69 +-- .../core/service/DefaultInterceptorChain.java | 19 +- .../com/osx/core/service/InboundPackage.java | 12 +- .../com/osx/core/service/Interceptor.java | 7 +- .../osx/core/service/InterceptorChain.java | 6 +- .../com/osx/core/service/OutboundPackage.java | 0 .../com/osx/core/service/ServiceAdaptor.java | 10 +- .../com/osx/core/timer/HashedWheelTimer.java | 0 .../main/java/com/osx/core/timer/Timeout.java | 0 .../main/java/com/osx/core/timer/Timer.java | 0 .../java/com/osx/core/timer/TimerTask.java | 0 .../java/com/osx/core/token/TokenRequest.java | 0 .../java/com/osx/core/token/TokenResult.java | 0 .../com/osx/core/token/TokenResultStatus.java | 0 .../java/com/osx/core/utils/AssertUtil.java | 0 .../java/com/osx/core/utils/ClassUtils.java | 392 +++++++++++++ .../java/com/osx/core/utils/EncryptUtils.java | 4 +- .../java/com/osx/core/utils/FileUtils.java | 137 +++++ .../com/osx/core/utils/FlowLogPrinter.java | 16 +- .../java/com/osx/core/utils/FlowLogUtil.java | 22 + .../com/osx/core/utils/GetSystemInfo.java | 0 .../java/com/osx/core/utils/JsonUtil.java | 11 +- .../java/com/osx/core/utils/NetUtils.java | 0 .../java/com/osx/core/utils/OSXCertUtils.java | 147 +++++ .../osx/core/utils/OsxX509TrustManager.java | 247 ++++++++ .../com/osx/core/utils/PropertiesUtil.java | 131 +++++ .../java/com/osx/core/utils/RouterUtil.java | 0 .../java/com/osx/core/utils/ServerUtil.java | 0 .../com/osx/core/utils/ThreadPoolUtil.java | 0 .../com/osx/core/utils/ToStringUtils.java | 0 java/osx/pom.xml | 60 +- java/osx/proto/osx.proto | 50 +- .../fate/arch/federation/osx/_mq_channel.py | 18 +- 344 files changed, 7713 insertions(+), 3617 deletions(-) create mode 100644 java/osx/RELEASE.md delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/Bootstrap.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/interceptor/RequestHandleInterceptor.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/message/MessageDecoder.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java delete mode 100644 
java/osx/broker/src/main/java/com/osx/broker/server/OsxServer.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/service/UnaryCallService.java delete mode 100644 java/osx/broker/src/main/java/com/osx/broker/util/TransferUtil.java delete mode 100644 java/osx/broker/src/main/resources/broker.properties delete mode 100755 java/osx/broker/src/main/resources/route_table.json delete mode 100644 java/osx/core/src/main/java/com/osx/core/config/MetaInfo.java delete mode 100644 java/osx/core/src/main/java/com/osx/core/router/RouterInfo.java delete mode 100644 java/osx/core/src/main/java/com/osx/core/utils/FlowLogUtil.java create mode 100644 java/osx/osx-api/pom.xml rename java/osx/{core/src/main/java/com/osx/core/constant => osx-api/src/main/java/com/osx/api/constants}/Protocol.java (92%) create mode 100644 java/osx/osx-api/src/main/java/com/osx/api/context/Context.java create mode 100644 java/osx/osx-api/src/main/java/com/osx/api/router/RouterInfo.java rename java/osx/{core/src/main/java/com/osx/core => osx-api/src/main/java/com/osx/api/tech}/provider/TechProvider.java (84%) create mode 100644 java/osx/osx-api/src/main/java/com/osx/api/translator/Translator.java rename java/osx/{broker => osx-broker}/package.xml (73%) rename java/osx/{broker => osx-broker}/pom.xml (84%) create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/Bootstrap.java rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ServiceContainer.java (62%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/buffer/BufferStatus.java (100%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/buffer/ReadResult.java (100%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/buffer/ReadStatus.java (100%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/buffer/TransferBufferUtil.java (100%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/buffer/WriteResult.java (100%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/buffer/WriteStatus.java (100%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/callback/CompleteCallback.java (100%) create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/callback/CreateUserCallback.java rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/callback/DestoryCallback.java (100%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/callback/ErrorCallback.java (100%) create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/callback/MockDesGrpcEventHandler.java create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/callback/MsgEventCallback.java create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/callback/MsgEventDispatchCallback.java rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/constants/Direction.java (100%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/constants/MessageFlag.java (89%) rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/consumer/ConsumerManager.java (67%) create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/consumer/EventDrivenConsumer.java create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/consumer/EventDriverRule.java create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/consumer/GrpcEventHandler.java rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java (96%) create mode 100644 
java/osx/osx-broker/src/main/java/com/osx/broker/consumer/MessageEvent.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/consumer/RedirectConsumer.java (97%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/consumer/SourceGrpcEventHandler.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/consumer/StreamConsumer.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/consumer/UnaryConsumer.java (74%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/BaseProto.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java (97%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/CommandClient.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/CommandURI.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErEndpoint.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErFunctor.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErJob.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErPartition.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErProcessor.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErRollSiteHeader.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErSession.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErSessionMeta.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErStore.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErStoreLocator.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/ErTask.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/EventDriverMsgManager.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/IdUtils.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/MetaCommnads.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/PartitionerTypes.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PushEventHandler.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/PutBatchSinkPushRespSO.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java (73%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/RollPair.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/RollPairContext.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/SerdesTypes.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/SessionCommands.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/SessionConfKeys.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/eggroll/SessionStatus.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/flow/ClusterMetricStatistics.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/flow/ClusterRuleUtil.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/grpc/ContextPrepareInterceptor.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/grpc/ForwardPullRespSO.java (98%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/grpc/ForwardPushRespSO.java (98%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/grpc/PcpGrpcService.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java (78%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/grpc/PullRequestDataWrap.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java (72%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/grpc/QueueStreamBuilder.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/grpc/ServiceExceptionHandler.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/http/DispatchServlet.java (69%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/http/HttpClientPool.java (76%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/http/HttpsClientPool.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/http/PtpHttpResponse.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/PcpHandleInterceptor.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/PushHandleInterceptor.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java (64%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/TokenValidatorInterceptor.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/UnaryCallHandleInterceptor.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/AllocateMappedFileService.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/AppendMessageHandler.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/AppendMessageResult.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/AppendMessageStatus.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java (65%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/Message.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageDecoder.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/MessageExt.java (99%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java (85%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/MessageStoreConfig.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/MessageSysFlag.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/MessageWraper.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/message/SelectMappedBufferResult.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/metric/ClusterMetricLeapArray.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java (84%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ptp/PtpAckService.java (91%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java (89%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java (90%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java (98%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ptp/PtpPushService.java (78%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java (98%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpStreamTestService.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java (54%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/Consumer.java (95%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/CreateQueueResult.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/MappedFile.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/MappedFileQueue.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/PutMessageLock.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/PutMessageReentrantLock.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/PutMessageResult.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/PutMessageStatus.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/ReferenceResource.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/TransferQueue.java (74%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/TransferQueueApplyInfo.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/TransferQueueManager.java (73%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/queue/TransferQueueMonitorService.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/router/FateRouterService.java (81%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/router/RemoteRouterDataSource.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/router/RouterMetric.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterRegister.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterService.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/security/MockTokenGenerator.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenGenerator.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenGeneratorRegister.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenValidator.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenValidatorRegister.java
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/server/OsxServer.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/service/PushService.java (70%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/service/RegisterService.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/service/RouteService.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/service/TokenApplyService.java (97%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/service/UnaryCallService.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/store/IndexQueue.java (95%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/store/MessageStore.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/token/DefaultTokenService.java (94%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/ContextUtil.java (79%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/DateUtils.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/LibC.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/MessageConst.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/MessageId.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/ResourceUtil.java (96%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/util/TelnetUtil.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/TimeUtils.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/TransferExceptionUtil.java (100%)
 create mode 100644 java/osx/osx-broker/src/main/java/com/osx/broker/util/TransferUtil.java
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/util/UtilAll.java (99%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/zk/AbstractZookeeperClient.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/zk/ChildListener.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java (98%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/zk/DataListener.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/zk/EventType.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/zk/StateListener.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/zk/ZkConfig.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/broker/zk/ZookeeperClient.java (100%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/tech/provider/FateTechProvider.java (53%)
 rename java/osx/{broker => osx-broker}/src/main/java/com/osx/tech/provider/TechProviderRegister.java (51%)
 create mode 100644 java/osx/osx-broker/src/main/resources/broker/broker.properties
 rename java/osx/{broker/src/main/resources => osx-broker/src/main/resources/broker}/flowRule.json (100%)
 create mode 100644 java/osx/osx-broker/src/main/resources/broker/route_table.json
 create mode 100644 java/osx/osx-broker/src/main/resources/components/provider.properties
 create mode 100644 java/osx/osx-broker/src/main/resources/components/router.properties
 create mode 100644 java/osx/osx-broker/src/main/resources/components/translator.properties
 rename java/osx/{broker => osx-broker}/src/main/resources/log4j2.xml (100%)
 rename java/osx/{broker => osx-broker}/src/test/java/com/osx/broker/cluster/ClusterClientEndpointTest.java (100%)
 create mode 100644 java/osx/osx-broker/src/test/java/com/osx/broker/mock/MockHttpServer.java
 rename java/osx/{broker => osx-broker}/src/test/java/com/osx/broker/mock/MockServer.java (73%)
 rename java/osx/{broker => osx-broker}/src/test/java/com/osx/broker/test/grpc/EggrollTest.java (100%)
 create mode 100644 java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/Grpc_UC.java
 rename java/osx/{broker => osx-broker}/src/test/java/com/osx/broker/test/grpc/NewFateTest.java (52%)
 rename java/osx/{broker => osx-broker}/src/test/java/com/osx/broker/test/grpc/OldFateTest.java (89%)
 rename java/osx/{broker => osx-broker}/src/test/java/com/osx/broker/test/grpc/QueueTest.java (72%)
 rename java/osx/{broker => osx-broker}/src/test/java/com/osx/broker/test/grpc/SyncQueueTest.java (100%)
 create mode 100644 java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/UnaryCallTest.java
 rename java/osx/{broker => osx-broker}/src/test/java/com/osx/broker/test/http/HttpTest.java (76%)
 create mode 100644 java/osx/osx-broker/src/test/java/com/osx/broker/test/http/Http_PRODUCE_MSG.java
 create mode 100644 java/osx/osx-broker/src/test/java/com/osx/broker/test/utils/JsonToMapCode.java
 rename java/osx/{core => osx-core}/pom.xml (93%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/config/Config.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/config/GrpcChannelInfo.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/config/MasterInfo.java (100%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/config/MetaInfo.java
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/config/TransferMeta.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/ActionType.java (97%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/DeployMode.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/Dict.java (55%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/EncryptMethod.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/NegotiationType.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/PtpHttpHeader.java (94%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/constant/Role.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/StatusCode.java (88%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/StreamLimitMode.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/constant/TransferStatus.java (100%)
 rename java/osx/{core/src/main/java/com/osx/core/context/Context.java => osx-core/src/main/java/com/osx/core/context/FateContext.java} (57%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/datasource/AbstractDataSource.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/datasource/AutoRefreshDataSource.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/datasource/Converter.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/datasource/FileRefreshableDataSource.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/datasource/NamedThreadFactory.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/datasource/ReadableDataSource.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/AckIndexException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/BaseException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/ConfigErrorException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/ConsumeNoMessageException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/ConsumerNotExistException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/CreateTopicErrorException.java (100%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/exceptions/CycleRouteInfoException.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java (64%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/ExceptionInfo.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/InvalidRedirectInfoException.java (100%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/exceptions/InvalidRouteInfoException.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/MappedFileException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/MessageParseException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/NoRouterInfoException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/ParameterException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/ProduceMsgExcption.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/PutMessageException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/RemoteRpcException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/RouterInfoOperateException.java (100%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/exceptions/SessionInitException.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/SysException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/TransferQueueAlreadyExistException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/TransferQueueInvalidStatusException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java (88%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/exceptions/UnSupportMethodException.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/AbstractRule.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/BucketLeapArray.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterFlowChecker.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterFlowConfig.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterFlowEvent.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java (91%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterMetric.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterMetricBucket.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterMetricLeapArray.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterMetricStatistics.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterRuleConstant.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/ClusterRuleUtil.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/CurrentConcurrencyManager.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/DebugSupport.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/DynamicProperty.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/FileMetricReport.java (97%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/FlowCounter.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/FlowCounterManager.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/FlowRule.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/Function.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/GlobalRequestLimiter.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/LeapArray.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/LimitQueue.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/Metric.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/MetricBucket.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/MetricEvent.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/MetricNode.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/MetricReport.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/MetricSearcher.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/MetricWriter.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/MetricsReader.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/NamespaceFlowProperty.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/OccupySupport.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/Property.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/PropertyListener.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/RequestLimiter.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/Rule.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/RuleConstant.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/TimeUtil.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/TokenService.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/UnaryLeapArray.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/flow/WindowWrap.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/frame/CountDownLatch.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java (84%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/frame/Lifecycle.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/frame/ServiceDataWrapper.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/frame/ServiceThread.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/jvm/JVMGCUtils.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/jvm/JVMMemoryUtils.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/jvm/JVMThreadUtils.java (78%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/jvm/JvmInfo.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/jvm/JvmInfoCounter.java (89%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/jvm/JvmInfoLeapArray.java (100%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/provider/TechProvider.java
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/ptp/SourceMethod.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/ptp/TargetMethod.java (96%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/queue/ClusterTransferQueueInfo.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/queue/TranferQueueInfo.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java (70%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/service/DefaultInterceptorChain.java (65%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/service/InboundPackage.java (86%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/service/Interceptor.java (73%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/service/InterceptorChain.java (75%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/service/OutboundPackage.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/service/ServiceAdaptor.java (70%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/timer/HashedWheelTimer.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/timer/Timeout.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/timer/Timer.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/timer/TimerTask.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/token/TokenRequest.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/token/TokenResult.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/token/TokenResultStatus.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/AssertUtil.java (100%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/utils/ClassUtils.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/EncryptUtils.java (97%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/utils/FileUtils.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/FlowLogPrinter.java (81%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/utils/FlowLogUtil.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/GetSystemInfo.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/JsonUtil.java (94%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/NetUtils.java (100%)
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/utils/OSXCertUtils.java
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/utils/OsxX509TrustManager.java
 create mode 100644 java/osx/osx-core/src/main/java/com/osx/core/utils/PropertiesUtil.java
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/RouterUtil.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/ServerUtil.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/ThreadPoolUtil.java (100%)
 rename java/osx/{core => osx-core}/src/main/java/com/osx/core/utils/ToStringUtils.java (100%)

diff --git a/java/osx/README.md b/java/osx/README.md
index 964392f623..055f994d4e 100644
--- a/java/osx/README.md
+++ b/java/osx/README.md
@@ -1,7 +1 @@
-# Release 1.0.0-alpha
-## Major Features and Improvements
-* Support grpc synchronous transmission and streaming transmission. Compatible with eggroll interface and can replace FATE1. x rollsite component
-* Support asynchronous message transmission, which can replace rabbitmq and pulsar components in FATE1. x
-* Support HTTP1. X protocol transmission
-* Support cluster deployment and inter-site traffic control
-* Support networking as an Exchange component
\ No newline at end of file
+OSX: Open Site Exchange
\ No newline at end of file
diff --git a/java/osx/RELEASE.md b/java/osx/RELEASE.md
new file mode 100644
index 0000000000..964392f623
--- /dev/null
+++ b/java/osx/RELEASE.md
@@ -0,0 +1,7 @@
+# Release 1.0.0-alpha
+## Major Features and Improvements
+* Support gRPC synchronous and streaming transmission, compatible with the eggroll interface; can replace the FATE 1.x rollsite component
+* Support asynchronous message transmission, which can replace the RabbitMQ and Pulsar components in FATE 1.x
+* Support HTTP/1.x protocol transmission
+* Support cluster deployment and inter-site traffic control
+* Support networking as an Exchange component
\ No newline at end of file
diff --git a/java/osx/bin/common.sh b/java/osx/bin/common.sh
index 3101efaa39..adc76cd21b 100644
--- a/java/osx/bin/common.sh
+++ b/java/osx/bin/common.sh
@@ -84,7 +84,6 @@ JAVA_OPT="${JAVA_OPT} -XX:-OmitStackTraceInFastThrow"
 JAVA_OPT="${JAVA_OPT} -XX:+AlwaysPreTouch"
 JAVA_OPT="${JAVA_OPT} -XX:MaxDirectMemorySize=15g"
 JAVA_OPT="${JAVA_OPT} -XX:-UseLargePages -XX:-UseBiasedLocking"
-#JAVA_OPT="${JAVA_OPT} -Xdebug -Xrunjdwp:transport=dt_socket,address=9555,server=y,suspend=n"
 JAVA_OPT="${JAVA_OPT} ${JAVA_OPT_EXT}"
 
 set -e
@@ -93,7 +92,7 @@ getpid() {
     pid=$(cat ./bin/broker.pid)
   fi
   if [[ -n ${pid} ]]; then
-    count=$(ps -ef | grep $pid | grep -v "grep" | wc -l)
+    count=$(ps -ef | grep $pid |grep 'com.osx' | grep -v "grep" | wc -l)
     if [[ ${count} -eq 0 ]]; then
       rm ./bin/broker.pid
       unset pid
@@ -115,53 +114,50 @@ start() {
   getpid $module
   if [[ ! -n ${pid} ]]; then JAVA_OPT="${JAVA_OPT} "
     mklogsdir
-#    if [[ -e "${module}.jar" ]]; then
-#      rm ${module}.jar
-#    fi
-#    ln -s ${module}-${module_version}.jar ${module}.jar
-    JAVA_OPT="${JAVA_OPT} -cp conf/broker/:lib/*"
-#    if [ ${module} = "transfer" ]; then
-#      echo "transfer"
-#    elif [ ${module} = "cluster-manager" ] || [ ${module} = "dashboard" ]; then
-#      JAVA_OPT="${JAVA_OPT} -Dspring.config.location=${configpath}/cluster-manager.properties"
-#      JAVA_OPT="${JAVA_OPT} -cp conf/:lib/*:${module}.jar"
-#    else
-#      echo "usage: ${module} {transfer|cluster-manager|dashboard}"
-#    fi
-
+    JAVA_OPT="${JAVA_OPT} -cp conf/broker/:lib/*:extension/*:${BASE_DIR}/${project_name}-${module}-${module_version}.jar"
     JAVA_OPT="${JAVA_OPT} ${main_class}"
+    JAVA_OPT="${JAVA_OPT} -c ${configpath} "
+    echo $JAVA ${JAVA_OPT}
+    nohup $JAVA ${JAVA_OPT} >/dev/null 2>&1 &
+    inspect_pid 5 $!
+    if [[ "$exist" = 1 ]]; then
+      echo $! >./bin/${module}.pid
+      getpid ${module}
+      echo "service start successfully. pid: ${pid}"
+    else
+      echo "service start failed"
+    fi
+  else
+    echo "service already started. pid: ${pid}"
+  fi
+}
-    JAVA_OPT="${JAVA_OPT} -c ${configpath}/broker/broker.properties"
-#    if [ ${module} = "broker" -o ${module} = "cli" ]; then
-#      JAVA_OPT="${JAVA_OPT} -c ${configpath}/broker/broker.properties"
-#    elif [ ${module} = "cluster-manager" ]; then
-#      JAVA_OPT="${JAVA_OPT} -c ${configpath}/cluster-manager/cluster-manager.properties"
-#
-#    elif [ ${module} = "dashboard" ]; then
-#      JAVA_OPT="-jar ${libpath}/dashboard-1.0.0.jar -spring.config.location=${configpath}/dashboard/application.properties"
-#    fi
-
-    if [ ${module} = "cli" ]; then
-      java ${JAVA_OPT}
-    else
-      nohup $JAVA ${JAVA_OPT} >/dev/null 2>&1 &
-      #sleep 5
-      #id=$(ps -p $! | awk '{print $1}' | sed -n '2p')
-      inspect_pid 5 $!
-
-      if [[ "$exist" = 1 ]]; then
-        echo $! >./bin/${module}.pid
-        getpid ${module}
-        echo "service start sucessfully. pid: ${pid}"
-      else
-        echo "service start failed"
-      fi
-    fi
+debug() {
+  echo "try to start $1"
+  module=broker
+  main_class=com.osx.broker.Bootstrap
+  getpid $module
+  if [[ ! -n ${pid} ]]; then JAVA_OPT="${JAVA_OPT} "
+    mklogsdir
+    JAVA_OPT="${JAVA_OPT} -Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=8008 -cp conf/broker/:lib/*:extension/*:${BASE_DIR}/${project_name}-${module}-${module_version}.jar"
+    JAVA_OPT="${JAVA_OPT} ${main_class}"
+    JAVA_OPT="${JAVA_OPT} -c ${configpath} "
+    echo $JAVA ${JAVA_OPT}
+    nohup $JAVA ${JAVA_OPT} >/dev/null 2>&1 &
+    inspect_pid 5 $!
+    if [[ "$exist" = 1 ]]; then
+      echo $! >./bin/${module}.pid
+      getpid ${module}
+      echo "service start successfully. pid: ${pid}"
+    else
+      echo "service start failed"
+    fi
   else
     echo "service already started. pid: ${pid}"
   fi
 }
+
 status() {
   getpid $1
   if [[ -n ${pid} ]]; then
@@ -195,11 +191,9 @@ stop() {
   fi
 }
-
 
 inspect_pid() {
   total=0
   exist=0
-  #echo "inspect pid: $2,periods: $1"
   if [[ -n $2 ]]; then
     while [[ $total -le $1 ]]
     do
diff --git a/java/osx/bin/service.sh b/java/osx/bin/service.sh
index 07f15ba4b4..c8de9e8fd6 100644
--- a/java/osx/bin/service.sh
+++ b/java/osx/bin/service.sh
@@ -25,8 +25,8 @@ configpath=$(cd $basepath/conf;pwd)
 libpath=$(cd $basepath/lib;pwd)
 #module=transfer
 #main_class=com.firework.transfer.Bootstrap
-#module_version=1.0.0
-
+module_version=1.0.0-alpha
+project_name=osx
 
@@ -35,6 +35,10 @@ case "$1" in
     start $2
     status $2
     ;;
+  debug)
+    debug $2
+    status $2
+    ;;
   stop)
     stop $2
     ;;
@@ -47,6 +51,12 @@ case "$1" in
     start $2
     status $2
     ;;
+  redebug)
+    stop $2
+    sleep 0.5
+    debug $2
+    status $2
+    ;;
   *)
     echo "usage: $0 {start|stop|status|restart}"
     exit 1
diff --git a/java/osx/broker/src/main/java/com/osx/broker/Bootstrap.java b/java/osx/broker/src/main/java/com/osx/broker/Bootstrap.java
deleted file mode 100644
index 6c008efd23..0000000000
--- a/java/osx/broker/src/main/java/com/osx/broker/Bootstrap.java
+++ /dev/null
@@ -1,165 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.osx.broker;
-import com.google.common.collect.Lists;
-import com.osx.core.config.MetaInfo;
-import com.osx.core.constant.Dict;
-import com.osx.core.constant.StreamLimitMode;
-import com.osx.core.jvm.JvmInfoCounter;
-import com.osx.core.utils.JsonUtil;
-import com.osx.core.utils.NetUtils;
-import com.osx.core.utils.ServerUtil;
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Option;
-import org.apache.commons.cli.Options;
-import org.apache.commons.cli.PosixParser;
-import org.apache.commons.lang3.StringUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.*;
-import java.util.Properties;
-public class Bootstrap {
-    static Logger logger = LoggerFactory.getLogger(Bootstrap.class);
-    static CommandLine commandLine;
-    public static void main(String[] args) {
-        try {
-            Options options = ServerUtil.buildCommandlineOptions(new Options());
-            commandLine = ServerUtil.parseCmdLine("osx", args, buildCommandlineOptions(options),
-                new PosixParser());
-            String filePath = commandLine.getOptionValue('c');
-            logger.info("try to parse config file {}", filePath);
-            if (StringUtils.isEmpty(filePath)) {
-                System.err.println("config file is not set ,please use -c to set the config file path");
-                System.exit(-1);
-            }
-            parseConfig(filePath);
-            Bootstrap bootstrap = new Bootstrap();
-            bootstrap.start(args);
-            Thread shutDownThread = new Thread(() -> bootstrap.stop());
-            Runtime.getRuntime().addShutdownHook(shutDownThread);
-        } catch (Exception ex) {
-            System.exit(1);
-        }
-    }
-
-    private static Options buildCommandlineOptions(final Options options) {
-        Option opt = new Option("c", "configFile", true, "config properties file");
-        opt.setRequired(false);
-        options.addOption(opt);
-        return options;
-    }
-
-    public static void parseConfig(String configFilePath) {
-        try {
-            File file = new File(configFilePath);
-            Properties environment = new Properties();
-            try (InputStream inputStream = new BufferedInputStream(new FileInputStream(file))) {
-                environment.load(inputStream);
-            } catch (FileNotFoundException e) {
-                logger.error("profile broker.properties not found");
-                throw e;
-            } catch (IOException e) {
-                logger.error("parse config error, {}", e.getMessage());
-                throw e;
-            }
-            MetaInfo.PROPERTY_FATE_TECH_PROVIDER = environment.getProperty(Dict.PROPERTY_FATE_TECH_PROVIDER,"FATE");
-            MetaInfo.PROPERTY_ROOT_PATH = new File("").getCanonicalPath();
-            MetaInfo.PROPERTY_ROUTE_TABLE = environment.getProperty(Dict.PROPERTY_ROUTE_TABLE);
-            MetaInfo.PROPERTY_SERVER_CERTCHAIN_FILE = environment.getProperty(Dict.PROPERTY_SERVER_CERTCHAIN_FILE);
-            MetaInfo.PROPERTY_SERVER_PRIVATEKEY_FILE = environment.getProperty(Dict.PROPERTY_SERVER_PRIVATEKEY_FILE);
-            MetaInfo.PROPERTY_SERVER_CA_FILE = environment.getProperty(Dict.PROPERTY_SERVER_CA_FILE);
-            MetaInfo.PROPERTY_GRPC_TLS_PORT = Integer.valueOf(environment.getProperty(Dict.PROPERTY_GRPC_TLS_PORT, "9883"));
-            MetaInfo.PROPERTY_GRPC_PORT = Integer.valueOf(environment.getProperty(Dict.PROPERTY_GRPC_PORT, "9889"));
-            MetaInfo.PROPERTY_HTTP_PORT = Integer.valueOf(environment.getProperty(Dict.HTTP_PORT,"8762"));
-            MetaInfo.PROPERTY_PRINT_INPUT_DATA = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_PRINT_INPUT_DATA, "false"));
-            MetaInfo.PROPERTY_PRINT_OUTPUT_DATA = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_PRINT_OUTPUT_DATA, "false"));
-            MetaInfo.PROPERTY_USER_HOME = System.getProperty("user.home");
-            MetaInfo.PROPERTY_NEGOTIATIONTYPE = environment.getProperty(Dict.PROPERTY_NEGOTIATIONTYPE, "PLAINTEXT");
-            MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE = environment.getProperty(Dict.PROPERTY_TRANSFER_FILE_PATH, MetaInfo.PROPERTY_USER_HOME + "/.fate/transfer_file");
-            MetaInfo.PROPERTY_TRANSFER_FILE_CACHE_SIZE = environment.getProperty(Dict.PROPERTY_TRANSFER_FILE_CACHE_SIZE) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_TRANSFER_FILE_CACHE_SIZE)) : 1 << 27;
-            MetaInfo.PROPERTY_USE_DIRECT_CACHE = Boolean.parseBoolean(environment.getProperty(Dict.PROPERTY_USE_DIRECT_CACHE, "false"));
-            MetaInfo.PROPERTY_MAX_TRANSFER_CACHE_SIZE = environment.getProperty(Dict.PROPERTY_MAX_TRANSFER_CACHE_SIZE) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_MAX_TRANSFER_CACHE_SIZE)) : 1 << 30;
-            MetaInfo.PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT, "600"));
-            // MetaInfo.PROPERTY_USE_QUEUE_MODEL = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_USE_QUEUE_MODEL, "false"));
-            MetaInfo.PROPERTY_STREAM_LIMIT_MODE = environment.getProperty(Dict.PROPERTY_STREAM_LIMIT_MODE, StreamLimitMode.LOCAL.name());
-            MetaInfo.PROPERTY_STREAM_LIMIT_MAX_TRY_TIME = Integer.parseInt(environment.getProperty(Dict.PROPERTY_STREAM_LIMIT_MAX_TRY_TIME, "10"));
-            MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION, "1000"));
-            MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE = environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE)) : (2 << 30) - 1;
-            MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE = environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE)) : 128 << 20;
-            MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW = environment.getProperty(Dict.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW)) : 128 << 20;
-            MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, "7200"));
-            MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, "3600"));
-            MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, "120"));
-            MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED = Boolean.parseBoolean(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED, "false"));
-            MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, "86400"));
-            MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, "86400"));
-            MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, "86400"));
-            MetaInfo.TRANSFER_FATECLOUD_AHTHENTICATION_ENABLED = Boolean.valueOf(environment.getProperty(Dict.TRANSFER_FATECLOUD_AHTHENTICATION_ENABLED, "false"));
-            MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_USE_CONFIG = Boolean.valueOf(environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_USE_CONFIG, "false"));
-            MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_ROLE = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_ROLE, "guest");
-            MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_URI = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_URI, "/cloud-manager/api/site/rollsite/checkPartyId");
-            MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_APPKEY = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_APPKEY, "");
-            MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_APPSERCRET = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_APPSERCRET, "");
-            MetaInfo.TRANSFER_FATECLOUD_SECRET_INFO_URL = environment.getProperty(Dict.TRANSFER_FATECLOUD_SECRET_INFO_URL, "http://localhost:9091/fate-manager/api/site/secretinfo");
-            MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_URL = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_URL, "http://localhost:8999/cloud-manager/api/site/rollsite/checkPartyId");
-            MetaInfo.PROPERTY_SELF_PARTY.addAll(Lists.newArrayList(environment.getProperty(Dict.PROPERTY_SELF_PARTY, "").split(",")));
-
-            MetaInfo.HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT,"500"));
-            MetaInfo.HTTP_CLIENT_CONFIG_CONN_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_CONFIG_CONN_TIME_OUT,"2000"));
-            MetaInfo.HTTP_CLIENT_CONFIG_SOCK_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_CONFIG_SOCK_TIME_OUT,"3000"));
-            MetaInfo.HTTP_CLIENT_INIT_POOL_MAX_TOTAL = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_MAX_TOTAL,"500"));
-            MetaInfo.HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE,"200"));
-            MetaInfo.HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT,"10000"));
-            MetaInfo.HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT,"10000"));
-            MetaInfo.HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT,"10000"));
-            MetaInfo.HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT,"60000"));
-            MetaInfo.HTTP_CLIENT_TRAN_CONN_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_TRAN_CONN_TIME_OUT,"60000"));
-            MetaInfo.HTTP_CLIENT_TRAN_SOCK_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_TRAN_SOCK_TIME_OUT,"60000"));
-
-            MetaInfo.PRPPERTY_QUEUE_MAX_FREE_TIME = Integer.parseInt(environment.getProperty(Dict.PRPPERTY_QUEUE_MAX_FREE_TIME, "60000000"));
-            MetaInfo.INSTANCE_ID = NetUtils.getLocalHost() + ":" + MetaInfo.PROPERTY_GRPC_PORT;
-            MetaInfo.PROPERTY_DEPLOY_MODE = environment.getProperty(Dict.PROPERTY_DEPLOY_MODE);
-            MetaInfo.PROPERTY_CLUSTER_MANAGER_ADDRESS = environment.getProperty(Dict.PROPERTY_CLUSTER_MANAGER_ADDRESS);
-            MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_IP = environment.getProperty(Dict.PROPERTY_EGGROLL_CLUSTER_MANANGER_IP);
-            MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT = Integer.parseInt(environment.getProperty(Dict.PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT));
-            MetaInfo.PROPERTY_ZK_URL = environment.getProperty(Dict.PROPERTY_ZK_URL);
-            MetaInfo.PROPERTY_OPEN_HTTP_SERVER = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_OPEN_HTTP_SERVER, "false"));
-            MetaInfo.PROPERTY_OPEN_GRPC_TLS_SERVER = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_OPEN_GRPC_TLS_SERVER, "false"));
-//            public static Boolean PROPERTY_OPEN_HTTP_SERVER = false;
-//            public static Boolean PROPERTY_OPEN_GRPC_TLS_SERVER = false;
-            MetaInfo.PROPERTY_DEFAULT_CLIENT_VERSION = environment.getProperty(Dict.PROPERTY_DEFAULT_CLIENT_VERSION,"2.X.X");
-
-        } catch (Exception e) {
-            logger.error("init MetaInfo error", e);
-            System.exit(1);
-        }
-        logger.info("Meta Info {}", JsonUtil.formatJson(JsonUtil.object2Json(MetaInfo.toMap())));
-    }
-
-    public void start(String[] args) {
-        ServiceContainer.init();
-        JvmInfoCounter.start();
-    }
-
-    public void stop() {
-        logger.info("try to shutdown server ...");
-        if (ServiceContainer.transferQueueManager != null) {
-            ServiceContainer.transferQueueManager.destroyAll();
-        }
-    }
-
-}
\ No newline at end of file
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java b/java/osx/broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java
deleted file mode 100644
index a03c4510ae..0000000000
--- a/java/osx/broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.osx.broker.grpc;
-
-import com.webank.ai.eggroll.api.networking.proxy.Proxy;
-import io.grpc.stub.StreamObserver;
-
-public class PushRequestDataWrap {
-    Proxy.Packet packet;
-    StreamObserver streamObserver;
-
-    public Proxy.Packet getPacket() {
-        return packet;
-    }
-
-    public void setPacket(Proxy.Packet packet) {
-        this.packet = packet;
-    }
-
-    public StreamObserver getStreamObserver() {
-        return streamObserver;
-    }
-
-    public void setStreamObserver(StreamObserver streamObserver) {
-        this.streamObserver = streamObserver;
-    }
-}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/interceptor/RequestHandleInterceptor.java b/java/osx/broker/src/main/java/com/osx/broker/interceptor/RequestHandleInterceptor.java
deleted file mode 100644
index 907375f4a4..0000000000
--- a/java/osx/broker/src/main/java/com/osx/broker/interceptor/RequestHandleInterceptor.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.osx.broker.interceptor;
-
-import com.google.protobuf.InvalidProtocolBufferException;
-import com.osx.broker.grpc.PushRequestDataWrap;
-import com.osx.broker.router.FateRouterService;
-import com.osx.core.context.Context;
-import com.osx.core.exceptions.NoRouterInfoException;
-import com.osx.core.exceptions.ParameterException;
-import com.osx.core.router.RouterInfo;
-import com.osx.core.service.InboundPackage;
-import com.osx.core.service.Interceptor;
-import com.webank.ai.eggroll.api.networking.proxy.Proxy;
-import com.webank.eggroll.core.transfer.Transfer;
-import org.apache.commons.lang3.StringUtils;
-import org.ppc.ptp.Osx;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Map;
-
-public class RequestHandleInterceptor implements Interceptor {
-    Logger logger = LoggerFactory.getLogger(RequestHandleInterceptor.class);
-
-    public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception {
-        Object body = inboundPackage.getBody();
-
-        if (body instanceof Osx.Inbound) {
-            Osx.Inbound request = (Osx.Inbound) body;
-            Map<String, String> metaDataMap = request.getMetadataMap();
-            String version = metaDataMap.get(Osx.Header.Version.name());
-            String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name());
-            String traceId = metaDataMap.get(Osx.Header.TraceID.name());
-            String token = metaDataMap.get(Osx.Header.Token.name());
-            String sourceNodeId = metaDataMap.get(Osx.Header.SourceNodeID.name());
-            String targetNodeId = metaDataMap.get(Osx.Header.TargetNodeID.name());
-            String sourceInstId = metaDataMap.get(Osx.Header.SourceInstID.name());
-            String targetInstId = metaDataMap.get(Osx.Header.TargetInstID.name());
-            String sessionId = metaDataMap.get(Osx.Header.SessionID.name());
-            String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name());
-            String targetComponentName = metaDataMap.get(Osx.Metadata.TargetComponentName.name());
-            String sourceComponentName = metaDataMap.get(Osx.Metadata.SourceComponentName.name());
-            String sourcePartyId = StringUtils.isEmpty(sourceInstId) ? sourceNodeId : sourceInstId + "." + sourceNodeId;
-            String targetPartyId = StringUtils.isEmpty(targetInstId) ? targetNodeId : targetInstId + "." + targetNodeId;
-            String topic = metaDataMap.get(Osx.Metadata.MessageTopic.name());
-            String offsetString = metaDataMap.get(Osx.Metadata.MessageOffSet.name());
-            Long offset = StringUtils.isNotEmpty(offsetString) ? Long.parseLong(offsetString) : null;
-            context.setDesPartyId(targetPartyId);
-            context.setSrcPartyId(sourcePartyId);
-            context.setTopic(topic);
-            context.setRequestMsgIndex(offset);
-            context.setSessionId(sessionId);
-            context.setDesComponent(targetComponentName);
-            context.setSrcComponent(sourceComponentName);
-            return;
-        } else if (body instanceof PushRequestDataWrap) {
-            PushRequestDataWrap pushRequestDataWrap = (PushRequestDataWrap) body;
-            Proxy.Packet packet = pushRequestDataWrap.getPacket();
-            handleProxyPacket(context, packet);
-            return;
-        } else if (body instanceof Proxy.Packet) {
-            handleProxyPacket(context, (Proxy.Packet) body);
-        } else {
-            throw new ParameterException("invalid inbound type");
-        }
-    }
-
-    private void handleProxyPacket(Context context, Proxy.Packet packet) {
-        Proxy.Metadata metadata = packet.getHeader();
-        Transfer.RollSiteHeader rollSiteHeader = null;
-        try {
-            rollSiteHeader = Transfer.RollSiteHeader.parseFrom(metadata.getExt());
-        } catch (InvalidProtocolBufferException e) {
-            throw new ParameterException("invalid rollSiteHeader");
-        }
-        String dstPartyId = rollSiteHeader.getDstPartyId();
-        if (StringUtils.isEmpty(dstPartyId)) {
-            dstPartyId = metadata.getDst().getPartyId();
-        }
-
-        String desRole = metadata.getDst().getRole();
-        String srcRole = metadata.getSrc().getRole();
-        String srcPartyId = metadata.getSrc().getPartyId();
-        context.setSrcPartyId(srcPartyId);
-        context.setDesPartyId(dstPartyId);
-        context.setSrcComponent(srcRole);
-        context.setDesComponent(desRole);
-    }
-
-}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageDecoder.java b/java/osx/broker/src/main/java/com/osx/broker/message/MessageDecoder.java
deleted file mode 100644
index 5303e63303..0000000000
--- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageDecoder.java
+++ /dev/null
@@ -1,403 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -package com.osx.broker.message; - -import com.osx.broker.constants.MessageFlag; -import com.osx.broker.util.MessageId; -import com.osx.broker.util.UtilAll; - -import java.net.*; -import java.nio.ByteBuffer; -import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -public class MessageDecoder { -// public final static int MSG_ID_LENGTH = 8 + 8; - - public final static Charset CHARSET_UTF8 = Charset.forName("UTF-8"); - public final static int MESSAGE_MAGIC_CODE_POSTION = 4; - public final static int MESSAGE_FLAG_POSTION = 16; - public final static int MESSAGE_PHYSIC_OFFSET_POSTION = 28; - // public final static int MESSAGE_STORE_TIMESTAMP_POSTION = 56; - public final static int MESSAGE_MAGIC_CODE = -626843481; - public static final char NAME_VALUE_SEPARATOR = 1; - public static final char PROPERTY_SEPARATOR = 2; - public static final int PHY_POS_POSITION = 4 + 4 + 4 + 4 + 4 + 8; - public static final int QUEUE_OFFSET_POSITION = 4 + 4 + 4 + 4 + 4; - public static final int SYSFLAG_POSITION = 4 + 4 + 4 + 4 + 4 + 8 + 8; - - - public static String createMessageId(final ByteBuffer input, final ByteBuffer addr, final long offset) { - input.flip(); - int msgIDLength = addr.limit() == 8 ? 16 : 28; - input.limit(msgIDLength); - - input.put(addr); - input.putLong(offset); - - return UtilAll.bytes2string(input.array()); - } - - public static MessageExtBrokerInner buildMessageExtBrokerInner(String topic, byte[] body, - int queueId, MessageFlag flag, String srcPartyId, String desPartyId) { - MessageExtBrokerInner messageExtBrokerInner = new MessageExtBrokerInner(); - messageExtBrokerInner.setQueueId(queueId); - messageExtBrokerInner.setBody(body); - messageExtBrokerInner.setTopic(topic); - messageExtBrokerInner.setFlag(flag.getFlag()); - messageExtBrokerInner.setBornTimestamp(System.currentTimeMillis()); - messageExtBrokerInner.setDesPartyId(srcPartyId); - messageExtBrokerInner.setSrcPartyId(desPartyId); - return messageExtBrokerInner; - } - - public static String createMessageId(SocketAddress socketAddress, long transactionIdhashCode) { - InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; - int msgIDLength = inetSocketAddress.getAddress() instanceof Inet4Address ? 16 : 28; - ByteBuffer byteBuffer = ByteBuffer.allocate(msgIDLength); - byteBuffer.put(inetSocketAddress.getAddress().getAddress()); - byteBuffer.putInt(inetSocketAddress.getPort()); - byteBuffer.putLong(transactionIdhashCode); - byteBuffer.flip(); - return UtilAll.bytes2string(byteBuffer.array()); - } - - public static MessageId decodeMessageId(final String msgId) throws UnknownHostException { - SocketAddress address; - long offset; - int ipLength = msgId.length() == 32 ? 4 * 2 : 16 * 2; - - byte[] ip = UtilAll.string2bytes(msgId.substring(0, ipLength)); - byte[] port = UtilAll.string2bytes(msgId.substring(ipLength, ipLength + 8)); - ByteBuffer bb = ByteBuffer.wrap(port); - int portInt = bb.getInt(0); - address = new InetSocketAddress(InetAddress.getByAddress(ip), portInt); - - // offset - byte[] data = UtilAll.string2bytes(msgId.substring(ipLength + 8, ipLength + 8 + 16)); - bb = ByteBuffer.wrap(data); - offset = bb.getLong(0); - - return new MessageId(address, offset); - } - - /** - * Just decode properties from msg buffer. - * - * @param byteBuffer msg commit log buffer. 
- */ - public static Map decodeProperties(ByteBuffer byteBuffer) { - int sysFlag = byteBuffer.getInt(SYSFLAG_POSITION); - int bornhostLength = (sysFlag & MessageSysFlag.BORNHOST_V6_FLAG) == 0 ? 8 : 20; - int storehostAddressLength = (sysFlag & MessageSysFlag.STOREHOSTADDRESS_V6_FLAG) == 0 ? 8 : 20; - int bodySizePosition = 4 // 1 TOTALSIZE - + 4 // 2 MAGICCODE - + 4 // 3 BODYCRC - + 4 // 4 QUEUEID - + 4 // 5 FLAG - + 8 // 6 QUEUEOFFSET - + 8 // 7 PHYSICALOFFSET - + 4 // 8 SYSFLAG - + 8 // 9 BORNTIMESTAMP - + bornhostLength // 10 BORNHOST - + 8 // 11 STORETIMESTAMP - + storehostAddressLength // 12 STOREHOSTADDRESS - + 4 // 13 RECONSUMETIMES - + 8; // 14 Prepared Transaction Offset - int topicLengthPosition = bodySizePosition + 4 + byteBuffer.getInt(bodySizePosition); - - byte topicLength = byteBuffer.get(topicLengthPosition); - - short propertiesLength = byteBuffer.getShort(topicLengthPosition + 1 + topicLength); - - byteBuffer.position(topicLengthPosition + 1 + topicLength + 2); - - if (propertiesLength > 0) { - byte[] properties = new byte[propertiesLength]; - byteBuffer.get(properties); - String propertiesString = new String(properties, CHARSET_UTF8); - Map map = string2messageProperties(propertiesString); - return map; - } - return null; - } - - public static MessageExt decode(ByteBuffer byteBuffer) { - return decode(byteBuffer, true, true, false); - } - - public static MessageExt clientDecode(ByteBuffer byteBuffer, final boolean readBody) { - return decode(byteBuffer, readBody, true, true); - } - - public static MessageExt decode(ByteBuffer byteBuffer, final boolean readBody) { - return decode(byteBuffer, readBody, true, false); - } - - public static MessageExt decode( - ByteBuffer byteBuffer, final boolean readBody, final boolean deCompressBody) { - return decode(byteBuffer, readBody, deCompressBody, false); - } - - public static MessageExt decode( - ByteBuffer byteBuffer, final boolean readBody, final boolean deCompressBody, final boolean isClient) { - try { - - MessageExt msgExt= new MessageExt(); - // 1 TOTALSIZE - int storeSize = byteBuffer.getInt(); - msgExt.setStoreSize(storeSize); - - // 2 MAGICCODE - byteBuffer.getInt(); - - // 3 BODYCRC - int bodyCRC = byteBuffer.getInt(); - msgExt.setBodyCRC(bodyCRC); - - // 4 QUEUEID - int queueId = byteBuffer.getInt(); - msgExt.setQueueId(queueId); - - // 5 FLAG - int flag = byteBuffer.getInt(); - msgExt.setFlag(flag); - - // 6 QUEUEOFFSET - int srcPartyIdLength = byteBuffer.get(); - if (srcPartyIdLength > 0) { - byte[] srcPartyBytes = new byte[srcPartyIdLength]; - byteBuffer.get(srcPartyBytes); - String srcPartyId = new String(srcPartyBytes); - msgExt.setSrcPartyId(srcPartyId); - } - -// long queueOffset = byteBuffer.getLong(); -// msgExt.setQueueOffset(queueOffset); - - // 7 PHYSICALOFFSET -// long physicOffset = byteBuffer.getLong(); -// msgExt.setCommitLogOffset(physicOffset); - - - int desPartyIdLength = byteBuffer.get(); - if (desPartyIdLength > 0) { - byte[] desPartyIdBytes = new byte[desPartyIdLength]; - byteBuffer.get(desPartyIdBytes); - String desPartyId = new String(desPartyIdBytes); - msgExt.setDesPartyId(desPartyId); - } - - - // 8 SYSFLAG - int sysFlag = byteBuffer.getInt(); - msgExt.setSysFlag(sysFlag); - - // 9 BORNTIMESTAMP - long bornTimeStamp = byteBuffer.getLong(); - msgExt.setBornTimestamp(bornTimeStamp); - - - // 15 BODY - int bodyLen = byteBuffer.getInt(); - if (bodyLen > 0) { - if (readBody) { - byte[] body = new byte[bodyLen]; - byteBuffer.get(body); - msgExt.setBody(body); - } else { - 
byteBuffer.position(byteBuffer.position() + bodyLen); - } - } - - // 16 TOPIC - short topicLen = byteBuffer.getShort(); - byte[] topic = new byte[(int) topicLen]; - byteBuffer.get(topic); - msgExt.setTopic(new String(topic, CHARSET_UTF8)); - - // 17 properties - short propertiesLength = byteBuffer.getShort(); - if (propertiesLength > 0) { - byte[] properties = new byte[propertiesLength]; - byteBuffer.get(properties); - String propertiesString = new String(properties, CHARSET_UTF8); - Map map = string2messageProperties(propertiesString); - msgExt.setProperties(map); - } - - return msgExt; - } catch (Exception e) { - e.printStackTrace(); - byteBuffer.position(byteBuffer.limit()); - } - - return null; - } - - public static List decodes(ByteBuffer byteBuffer) { - return decodes(byteBuffer, true); - } - - public static List decodes(ByteBuffer byteBuffer, final boolean readBody) { - List msgExts = new ArrayList(); - while (byteBuffer.hasRemaining()) { - MessageExt msgExt = clientDecode(byteBuffer, readBody); - if (null != msgExt) { - msgExts.add(msgExt); - } else { - break; - } - } - return msgExts; - } - - public static String messageProperties2String(Map properties) { - StringBuilder sb = new StringBuilder(); - if (properties != null) { - for (final Map.Entry entry : properties.entrySet()) { - final String name = entry.getKey(); - final String value = entry.getValue(); - - if (value == null) { - continue; - } - sb.append(name); - sb.append(NAME_VALUE_SEPARATOR); - sb.append(value); - sb.append(PROPERTY_SEPARATOR); - } - } - return sb.toString(); - } - - public static Map string2messageProperties(final String properties) { - Map map = new HashMap(); - if (properties != null) { - String[] items = properties.split(String.valueOf(PROPERTY_SEPARATOR)); - for (String i : items) { - String[] nv = i.split(String.valueOf(NAME_VALUE_SEPARATOR)); - if (2 == nv.length) { - map.put(nv[0], nv[1]); - } - } - } - - return map; - } - - public static byte[] encodeMessage(Message message) { - //only need flag, body, properties - byte[] body = message.getBody(); - int bodyLen = body.length; - String properties = messageProperties2String(message.getProperties()); - byte[] propertiesBytes = properties.getBytes(CHARSET_UTF8); - //note properties length must not more than Short.MAX - short propertiesLength = (short) propertiesBytes.length; - int sysFlag = message.getFlag(); - int storeSize = 4 // 1 TOTALSIZE - + 4 // 2 MAGICCOD - + 4 // 3 BODYCRC - + 4 // 4 FLAG - + 4 + bodyLen // 4 BODY - + 2 + propertiesLength; - ByteBuffer byteBuffer = ByteBuffer.allocate(storeSize); - // 1 TOTALSIZE - byteBuffer.putInt(storeSize); - - // 2 MAGICCODE - byteBuffer.putInt(0); - - // 3 BODYCRC - byteBuffer.putInt(0); - - // 4 FLAG - int flag = message.getFlag(); - byteBuffer.putInt(flag); - - // 5 BODY - byteBuffer.putInt(bodyLen); - byteBuffer.put(body); - - // 6 properties - byteBuffer.putShort(propertiesLength); - byteBuffer.put(propertiesBytes); - - return byteBuffer.array(); - } - - public static Message decodeMessage(ByteBuffer byteBuffer) throws Exception { - Message message = new Message(); - - // 1 TOTALSIZE - byteBuffer.getInt(); - - // 2 MAGICCODE - byteBuffer.getInt(); - - // 3 BODYCRC - byteBuffer.getInt(); - - // 4 FLAG - int flag = byteBuffer.getInt(); - message.setFlag(flag); - - // 5 BODY - int bodyLen = byteBuffer.getInt(); - byte[] body = new byte[bodyLen]; - byteBuffer.get(body); - message.setBody(body); - - // 6 properties - short propertiesLen = byteBuffer.getShort(); - byte[] propertiesBytes = new 
byte[propertiesLen]; - byteBuffer.get(propertiesBytes); - message.setProperties(string2messageProperties(new String(propertiesBytes, CHARSET_UTF8))); - - return message; - } - - public static byte[] encodeMessages(List messages) { - //TO DO refactor, accumulate in one buffer, avoid copies - List encodedMessages = new ArrayList(messages.size()); - int allSize = 0; - for (Message message : messages) { - byte[] tmp = encodeMessage(message); - encodedMessages.add(tmp); - allSize += tmp.length; - } - byte[] allBytes = new byte[allSize]; - int pos = 0; - for (byte[] bytes : encodedMessages) { - System.arraycopy(bytes, 0, allBytes, pos, bytes.length); - pos += bytes.length; - } - return allBytes; - } - - public static List decodeMessages(ByteBuffer byteBuffer) throws Exception { - //TO DO add a callback for processing, avoid creating lists - List msgs = new ArrayList(); - while (byteBuffer.hasRemaining()) { - Message msg = decodeMessage(byteBuffer); - msgs.add(msg); - } - return msgs; - } -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java b/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java deleted file mode 100644 index 4a3183b05a..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
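For orientation on the codec being removed here: messageProperties2String flattens a property map into `name + NAME_VALUE_SEPARATOR + value + PROPERTY_SEPARATOR` pairs, and string2messageProperties splits on the same two separators to rebuild the map. Below is a minimal sketch of that round-trip; the separator characters are placeholders (the real constants are defined outside the lines shown in this diff), and there is no escaping, so values containing a separator would not survive.

```java
import java.util.HashMap;
import java.util.Map;

public class PropertiesCodecSketch {
    // Placeholder separators; the actual NAME_VALUE_SEPARATOR / PROPERTY_SEPARATOR
    // constants live outside the lines shown in this diff.
    static final char NAME_VALUE_SEPARATOR = '\u0001';
    static final char PROPERTY_SEPARATOR = '\u0002';

    static String encode(Map<String, String> props) {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, String> e : props.entrySet()) {
            if (e.getValue() == null) continue; // null values are skipped, as in the original
            sb.append(e.getKey()).append(NAME_VALUE_SEPARATOR)
              .append(e.getValue()).append(PROPERTY_SEPARATOR);
        }
        return sb.toString();
    }

    static Map<String, String> decode(String s) {
        Map<String, String> map = new HashMap<>();
        for (String item : s.split(String.valueOf(PROPERTY_SEPARATOR))) {
            String[] nv = item.split(String.valueOf(NAME_VALUE_SEPARATOR));
            if (nv.length == 2) map.put(nv[0], nv[1]); // malformed pairs are dropped
        }
        return map;
    }

    public static void main(String[] args) {
        Map<String, String> props = new HashMap<>();
        props.put("sessionId", "demo_guest_9999");
        System.out.println(decode(encode(props))); // {sessionId=demo_guest_9999}
    }
}
```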
- */ -package com.osx.broker.ptp; - -import com.osx.broker.ServiceContainer; -import com.osx.core.constant.ActionType; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ParameterException; -import com.osx.core.service.InboundPackage; -import org.apache.commons.lang3.StringUtils; -import org.ppc.ptp.Osx; - - -public class PtpClusterTopicApplyService extends AbstractPtpServiceAdaptor { - @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { - context.setActionType(ActionType.TOPIC_APPLY.getAlias()); - Osx.Inbound inbound = data.getBody(); - String topic = inbound.getMetadataMap().get(Osx.Metadata.MessageTopic.name()); - String instanceId = inbound.getMetadataMap().get(Osx.Metadata.InstanceId.name()); - String sessionId = inbound.getMetadataMap().get(Osx.Header.SessionID.name()); - if(StringUtils.isEmpty(topic)) - { - throw new ParameterException("topic is null"); - } - if(StringUtils.isEmpty(instanceId)) - { - throw new ParameterException("instanceId is null"); - } - if(StringUtils.isEmpty(sessionId)) - { - throw new ParameterException("sessionId is null"); - } - context.setTopic(topic); - context.setSessionId(sessionId); - Osx.Outbound outbound = ServiceContainer.transferQueueManager.applyFromMaster( topic,sessionId,instanceId); - return outbound; - } - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java b/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java deleted file mode 100644 index 7f67d66035..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-package com.osx.broker.ptp;
-
-import com.google.common.base.Preconditions;
-import com.osx.broker.ServiceContainer;
-import com.osx.broker.consumer.UnaryConsumer;
-import com.osx.broker.queue.CreateQueueResult;
-import com.osx.broker.queue.TransferQueue;
-import com.osx.broker.queue.TransferQueueApplyInfo;
-import com.osx.broker.util.TransferUtil;
-import com.osx.core.config.MetaInfo;
-import com.osx.core.constant.ActionType;
-import com.osx.core.constant.Dict;
-import com.osx.core.constant.StatusCode;
-import com.osx.core.context.Context;
-import com.osx.core.exceptions.ParameterException;
-import com.osx.core.exceptions.TransferQueueNotExistException;
-import com.osx.core.frame.GrpcConnectionFactory;
-import com.osx.core.router.RouterInfo;
-import com.osx.core.service.InboundPackage;
-import io.grpc.ManagedChannel;
-import io.grpc.stub.StreamObserver;
-import org.ppc.ptp.Osx;
-import org.ppc.ptp.PrivateTransferProtocolGrpc;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class PtpConsumeService extends AbstractPtpServiceAdaptor {
-
-    Logger logger = LoggerFactory.getLogger(PtpConsumeService.class);
-
-    public PtpConsumeService() {
-        this.setServiceName("consume-unary");
-    }
-
-    @Override
-    protected Osx.Outbound doService(Context context, InboundPackage data) {
-        context.setActionType(ActionType.DEFUALT_CONSUME.getAlias());
-        Osx.Inbound inbound = data.getBody();
-        String topic = context.getTopic();
-        TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(topic);
-        if (transferQueue == null) {
-
-            if(MetaInfo.isCluster()) {
-                TransferQueueApplyInfo transferQueueApplyInfo = ServiceContainer.transferQueueManager.queryGlobleQueue(topic);
-                if (transferQueueApplyInfo == null) {
-                    throw new TransferQueueNotExistException();
-                } else {
-                    String[] args = transferQueueApplyInfo.getInstanceId().split(":");
-                    String ip = args[0];
-                    int port = Integer.parseInt(args[1]);
-                    RouterInfo routerInfo = new RouterInfo();
-                    routerInfo.setHost(ip);
-                    routerInfo.setPort(port);
-                    context.setRouterInfo(routerInfo);
-                    return redirect(context, inbound);
-                }
-            }else{
-                /**
-                 * In standalone mode, create the queue directly
-                 */
-                logger.warn("create topic {} by consume request ",topic);
-                CreateQueueResult createQueueResult = ServiceContainer.transferQueueManager.createNewQueue( topic, context.getSessionId(), true);
-                if(createQueueResult.getTransferQueue()==null){
-                    throw new TransferQueueNotExistException();
-                }
-            }
-        }
-        StreamObserver streamObserver = (StreamObserver) context.getData(Dict.RESPONSE_STREAM_OBSERVER);
-        Long offset = context.getRequestMsgIndex();
-        Preconditions.checkArgument(offset != null);
-        if(offset==null){
-            throw new ParameterException("offset is null");
-        }
-        if (offset > 0) {
-            context.setActionType(ActionType.CUSTOMER_CONSUME.getAlias());
-        }
-        UnaryConsumer consumer = ServiceContainer.consumerManager.getOrCreateUnaryConsumer(topic);
-        TransferQueue.TransferQueueConsumeResult transferQueueConsumeResult = consumer.consume(context, offset);
-        context.setReturnCode(transferQueueConsumeResult.getCode());
-        if (transferQueueConsumeResult.getCode().equals(StatusCode.CONSUME_NO_MESSAGE)) {
-            /*
-             * The response will be sent later by the scanning thread
-             */
-            if (offset < 0) {
-                UnaryConsumer.LongPullingHold longPullingHold = new UnaryConsumer.LongPullingHold();
-                longPullingHold.setNeedOffset(offset);
-                longPullingHold.setStreamObserver(streamObserver);
-                longPullingHold.setContext(context.subContext());
-                consumer.addLongPullingQueue(longPullingHold);
-                return null;
-            }
-        }
-        Osx.Outbound consumeResponse = TransferUtil.buildResponse(transferQueueConsumeResult.getCode(), "", transferQueueConsumeResult);
-        return consumeResponse;
-
-    }
-
-    private Osx.Outbound redirect(Context context, Osx.Inbound inbound) {
-        ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(),true);
-        context.setActionType(ActionType.REDIRECT_CONSUME.getAlias());
-        PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel);
-        return stub.invoke(inbound);
-    }
-
-
-}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java b/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java
deleted file mode 100644
index d832176a2c..0000000000
--- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.osx.broker.ptp;
-
-import com.osx.broker.ServiceContainer;
-import com.osx.broker.constants.MessageFlag;
-import com.osx.broker.message.MessageDecoder;
-import com.osx.broker.message.MessageExtBrokerInner;
-import com.osx.broker.queue.CreateQueueResult;
-import com.osx.broker.queue.PutMessageResult;
-import com.osx.broker.queue.PutMessageStatus;
-import com.osx.broker.queue.TransferQueue;
-import com.osx.broker.util.TransferUtil;
-import com.osx.core.config.MetaInfo;
-import com.osx.core.constant.ActionType;
-import com.osx.core.constant.DeployMode;
-import com.osx.core.constant.Dict;
-import com.osx.core.constant.StatusCode;
-import com.osx.core.context.Context;
-import com.osx.core.exceptions.*;
-import com.osx.core.router.RouterInfo;
-import com.osx.core.service.InboundPackage;
-import org.apache.commons.lang3.StringUtils;
-import org.ppc.ptp.Osx;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import static com.osx.broker.util.TransferUtil.redirect;
-
-public class PtpProduceService extends AbstractPtpServiceAdaptor {
-
-    Logger logger = LoggerFactory.getLogger(PtpProduceService.class);
-
-    @Override
-    protected Osx.Outbound doService(Context context, InboundPackage data) {
-
-        String topic = context.getTopic();
-        boolean isDst = false;
-        RouterInfo routerInfo = context.getRouterInfo();
-        String srcPartyId = context.getSrcPartyId();
-        String sessionId = context.getSessionId();
-        Osx.Inbound produceRequest = data.getBody();
-        if (MetaInfo.PROPERTY_SELF_PARTY.contains(context.getDesPartyId())) {
-            isDst = true;
-        }
-        if (!isDst) {
-            /**
-             * Forward to an external party
-             */
-            return redirect(context, produceRequest, routerInfo, false);
-        } else {
-            /**
-             * Handle locally
-             */
-            if (StringUtils.isEmpty(topic)) {
-                throw new ParameterException(StatusCode.PARAM_ERROR, "topic is null");
-            }
-            if (StringUtils.isEmpty(sessionId)) {
-                throw new ParameterException(StatusCode.PARAM_ERROR, "sessionId is null");
-            }
-            context.setActionType(ActionType.MSG_DOWNLOAD.getAlias());
-            context.setRouterInfo(null);
-            TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(topic);
-            CreateQueueResult createQueueResult = null;
-            if( transferQueue==null) {
-                createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(topic, sessionId, false);
-                if (createQueueResult == null) {
-                    throw new CreateTopicErrorException("create topic " + topic + " error");
-                }
-                transferQueue = createQueueResult.getTransferQueue();
-            }
-            String resource = TransferUtil.buildResource(produceRequest);
-            int dataSize = produceRequest.getSerializedSize();
-            ServiceContainer.tokenApplyService.applyToken(context,resource,dataSize);
-            ServiceContainer.flowCounterManager.pass(resource,dataSize);
-            if (transferQueue != null) {
-                byte[] msgBytes = produceRequest.getPayload().toByteArray();
-                MessageExtBrokerInner messageExtBrokerInner = MessageDecoder.buildMessageExtBrokerInner(topic, msgBytes, 0, MessageFlag.MSG, context.getSrcPartyId(),
-                    context.getDesPartyId());
-                PutMessageResult putMessageResult = transferQueue.putMessage(messageExtBrokerInner);
-                if (putMessageResult.getPutMessageStatus() != PutMessageStatus.PUT_OK) {
-                    throw new PutMessageException("put status " + putMessageResult.getPutMessageStatus());
-                }
-                long logicOffset = putMessageResult.getMsgLogicOffset();
-                context.setCurrentMsgIndex(logicOffset);
-                Osx.Outbound.Builder outBoundBuilder = Osx.Outbound.newBuilder();
-                outBoundBuilder.setCode(StatusCode.SUCCESS);
-                outBoundBuilder.setMessage(Dict.SUCCESS);
-                return outBoundBuilder.build();
-            } else {
-                /**
-                 * Forward within the cluster
-                 */
-                if (MetaInfo.PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name())) {
-                    RouterInfo redirectRouterInfo = new RouterInfo();
-                    String redirectIp = createQueueResult.getRedirectIp();
-                    int redirectPort = createQueueResult.getPort();
-                    if (StringUtils.isEmpty(redirectIp) || redirectPort == 0) {
-                        logger.error("invalid redirect info {}:{}", redirectIp, redirectPort);
-                        throw new InvalidRedirectInfoException();
-                    }
-                    redirectRouterInfo.setHost(redirectIp);
-                    redirectRouterInfo.setPort(redirectPort);
-                    context.setRouterInfo(redirectRouterInfo);
-                    context.setActionType(ActionType.INNER_REDIRECT.getAlias());
-                    return redirect(context, produceRequest, redirectRouterInfo, true);
-                } else {
-                    logger.error("create topic {} error", topic);
-                    throw new ProduceMsgExcption();
-                }
-            }
-        }
-    }
-
-
-}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java b/java/osx/broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java
deleted file mode 100644
index e09f9ea1ec..0000000000
--- a/java/osx/broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java
+++ /dev/null
@@ -1,260 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
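The produce path above reduces to three decisions: a request whose destination party is not in MetaInfo.PROPERTY_SELF_PARTY is redirected outward; otherwise the payload is written to the local TransferQueue when one exists or can be created here; failing that, a cluster deployment redirects to the broker instance that owns the queue, while standalone mode treats it as an error. A condensed, hypothetical sketch of that branching (the Route enum and method names are illustrative stand-ins, not the actual FATE/OSX API):

```java
import java.util.Set;

// Condensed sketch of the removed PtpProduceService decision flow.
public class ProduceRoutingSketch {
    static final Set<String> SELF_PARTY = Set.of("10000"); // stands in for MetaInfo.PROPERTY_SELF_PARTY

    enum Route { REDIRECT_TO_PEER, PUT_TO_LOCAL_QUEUE, REDIRECT_INSIDE_CLUSTER, ERROR }

    static Route route(String desPartyId, boolean queueAvailableLocally, boolean clusterMode) {
        if (!SELF_PARTY.contains(desPartyId)) {
            return Route.REDIRECT_TO_PEER;          // destination is another party: forward outward
        }
        if (queueAvailableLocally) {
            return Route.PUT_TO_LOCAL_QUEUE;        // write into the local TransferQueue
        }
        return clusterMode
                ? Route.REDIRECT_INSIDE_CLUSTER     // queue was created on another broker instance
                : Route.ERROR;                      // standalone: failing to host the queue is fatal
    }

    public static void main(String[] args) {
        System.out.println(route("9999", false, false));  // REDIRECT_TO_PEER
        System.out.println(route("10000", true, false));  // PUT_TO_LOCAL_QUEUE
        System.out.println(route("10000", false, true));  // REDIRECT_INSIDE_CLUSTER
    }
}
```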
- */ -package com.osx.broker.router; -import com.google.common.base.Preconditions; -import com.google.common.collect.Maps; -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; -import com.google.protobuf.InvalidProtocolBufferException; -import com.osx.core.constant.NegotiationType; -import com.osx.core.datasource.FileRefreshableDataSource; -import com.osx.core.flow.PropertyListener; -import com.osx.core.router.RouterInfo; -import com.osx.core.utils.JsonUtil; -import com.webank.ai.eggroll.api.networking.proxy.Proxy; -import com.webank.eggroll.core.transfer.Transfer; -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.FileNotFoundException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -public class DefaultFateRouterServiceImpl implements FateRouterService { - - private static final String IP = "ip"; - private static final String PORT = "port"; - private static final String URL = "url"; - private static final String USE_SSL = "useSSL"; - private static final String HOSTNAME = "hostname"; - private static final String negotiationType = "negotiationType"; - private static final String certChainFile = "certChainFile"; - private static final String privateKeyFile = "privateKeyFile"; - private static final String caFile = "caFile"; - private static final String DEFAULT = "default"; - private static final String VERSION = "version"; - Logger logger = LoggerFactory.getLogger(DefaultFateRouterServiceImpl.class); - Map> routerInfoMap = new ConcurrentHashMap>(); - Map>> endPointMap = new ConcurrentHashMap<>(); - FileRefreshableDataSource fileRefreshableDataSource; - - @Override - public RouterInfo route(Proxy.Packet packet) { - Preconditions.checkArgument(packet != null); - RouterInfo routerInfo = null; - Proxy.Metadata metadata = packet.getHeader(); - Transfer.RollSiteHeader rollSiteHeader = null; - try { - rollSiteHeader = Transfer.RollSiteHeader.parseFrom(metadata.getExt()); - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); - } - String dstPartyId = rollSiteHeader.getDstPartyId(); - - if (StringUtils.isEmpty(dstPartyId)) { - dstPartyId = metadata.getDst().getPartyId(); - } - dstPartyId = metadata.getDst().getPartyId(); - String desRole = metadata.getDst().getRole(); - String srcRole = metadata.getSrc().getRole(); - String srcPartyId = metadata.getSrc().getPartyId(); - routerInfo = this.route(srcPartyId, srcRole, dstPartyId, desRole); - //logger.info("query router info {} to {} {} return {}", srcPartyId, dstPartyId, desRole, routerInfo); - return routerInfo; - } - - - public RouterInfo route(String srcPartyId, String srcRole, String dstPartyId, String desRole) { - RouterInfo routerInfo = null; - Map> partyIdMap = this.endPointMap.get(dstPartyId); - if (partyIdMap != null) { - - if (StringUtils.isNotEmpty(desRole)&&partyIdMap.get(desRole) != null) { - List ips = partyIdMap.getOrDefault(desRole, null); - if (ips != null && ips.size() > 0) { - Map endpoint = ips.get((int) (System.currentTimeMillis() % ips.size())); - routerInfo = new RouterInfo(); - routerInfo.setHost(endpoint.get(IP).toString()); - routerInfo.setPort(((Number) endpoint.get(PORT)).intValue()); - routerInfo.setDesPartyId(dstPartyId); - routerInfo.setSourcePartyId(srcPartyId); - routerInfo.setVersion(endpoint.get(VERSION) != null ? 
endpoint.get(VERSION).toString() : null); - routerInfo.setNegotiationType(endpoint.get(negotiationType)!=null?endpoint.get(negotiationType).toString():""); - } - } else { - - List ips = partyIdMap.getOrDefault(DEFAULT, null); - if (ips != null && ips.size() > 0) { - Map endpoint = ips.get((int) (System.currentTimeMillis() % ips.size())); - routerInfo = new RouterInfo(); - routerInfo.setHost(endpoint.get(IP).toString()); - routerInfo.setPort(((Number) endpoint.get(PORT)).intValue()); - routerInfo.setDesPartyId(dstPartyId); - routerInfo.setSourcePartyId(srcPartyId); - routerInfo.setVersion(endpoint.get(VERSION) != null ? endpoint.get(VERSION).toString() : null); - routerInfo.setNegotiationType(endpoint.get(negotiationType)!=null?endpoint.get(negotiationType).toString():""); - } - if(StringUtils.isNotEmpty(desRole)){ - logger.warn("role {} is not found,return default router info ",desRole); - } - } - } - return routerInfo; - } - - - - Map>> initRouteTable(Map confJson) { - // BasicMeta.Endpoint.Builder endpointBuilder = BasicMeta.Endpoint.newBuilder(); - Map>> newRouteTable = new ConcurrentHashMap<>(); - // loop through coordinator - - confJson.forEach((k,v)->{ - String coordinatorKey = k.toString(); - Map coordinatorValue = (Map)v; - - Map> serviceTable = newRouteTable.get(coordinatorKey); - if (serviceTable == null) { - serviceTable = new ConcurrentHashMap<>(4); - newRouteTable.put(coordinatorKey, serviceTable); - } - // loop through role in coordinator - for (Object roleEntryObject : coordinatorValue.entrySet()) { - Map.Entry roleEntry = (Map.Entry)roleEntryObject; - String roleKey = roleEntry.getKey().toString(); - if (roleKey.equals("createTime") || roleKey.equals("updateTime")) { - continue; - } - List roleValue = (List)roleEntry.getValue(); - - List endpoints = serviceTable.get(roleKey); - if (endpoints == null) { - endpoints = new ArrayList<>(); - serviceTable.put(roleKey, endpoints); - } - - // loop through endpoints - for (Object endpointElement : roleValue) { - - Map element = Maps.newHashMap(); - - Map endpointJson = (Map)endpointElement; - - if (endpointJson.get(IP)!=null) { - String targetIp = endpointJson.get(IP).toString(); - element.put(IP, targetIp); - } - - if (endpointJson.get(PORT)!=null) { - int targetPort = Integer.parseInt(endpointJson.get(PORT).toString()); - element.put(PORT, targetPort); - } -// if(endpointJson.has(URL)){ -// String url = endpointJson.get(URL).getAsString(); -// endpointBuilder.setUrl(url); -// } - - if (endpointJson.get(USE_SSL)!=null) { - boolean targetUseSSL = Boolean.getBoolean(endpointJson.get(USE_SSL).toString()); - element.put(USE_SSL, targetUseSSL); - } - - if (endpointJson.get(HOSTNAME)!=null) { - String targetHostname = endpointJson.get(HOSTNAME).toString(); - element.put(HOSTNAME, targetHostname); - } - - if (endpointJson.get(negotiationType)!=null) { - String targetNegotiationType = endpointJson.get(negotiationType).toString(); - element.put(negotiationType, targetNegotiationType); - }else{ - element.put(negotiationType, NegotiationType.PLAINTEXT); - } - - if (endpointJson.get(certChainFile)!=null) { - String targetCertChainFile = endpointJson.get(certChainFile).toString(); - element.put(certChainFile, targetCertChainFile); - } - - if (endpointJson.get(privateKeyFile)!=null) { - String targetPrivateKeyFile = endpointJson.get(privateKeyFile).toString(); - element.put(privateKeyFile, targetPrivateKeyFile); - } - - if (endpointJson.get(caFile)!=null) { - String targetCaFile = endpointJson.get(caFile).toString(); - element.put(caFile, 
targetCaFile); - } - if (endpointJson.get(VERSION)!=null) { - String targetVersion = endpointJson.get(VERSION).toString(); - element.put(VERSION, targetVersion); - } - - //BasicMeta.Endpoint endpoint = endpointBuilder.build(); - endpoints.add(element); - } - } - - }); - - return newRouteTable; - } - - public void start() { - String currentPath = Thread.currentThread().getContextClassLoader().getResource("route_table.json").getPath(); - logger.info("load router file {}", currentPath); - File confFile = new File(currentPath); - FileRefreshableDataSource fileRefreshableDataSource = null; - try { - fileRefreshableDataSource = new FileRefreshableDataSource(confFile, (source) -> { - logger.info("read route_table {}", source); - return source; - }); - fileRefreshableDataSource.getProperty().addListener(new RouterTableListener()); - - } catch (FileNotFoundException e) { - logger.error("router file {} is not found", currentPath); - } - } - - private class RouterTableListener implements PropertyListener { - - @Override - public void configUpdate(String value) { - // logger.info("fire router table update {}",value); - Map confJson = JsonUtil.json2Object(value,Map.class); - // JsonObject confJson = JsonParser.parseString(value).getAsJsonObject(); - Map content =(Map) confJson.get("route_table"); - endPointMap = initRouteTable(content); - } - - @Override - public void configLoad(String value) { - Map confJson = JsonUtil.json2Object(value,Map.class); - Map content =(Map) confJson.get("route_table"); - endPointMap = initRouteTable(content); - logger.info("load router config {}", JsonUtil.formatJson(JsonUtil.object2Json(endPointMap))); - } - } - - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/server/OsxServer.java b/java/osx/broker/src/main/java/com/osx/broker/server/OsxServer.java deleted file mode 100644 index b78a399b6f..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/server/OsxServer.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
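The routing rule implemented above: look up the destination party in the parsed route table, prefer the endpoint list registered for the requested role, fall back to the "default" list, and pick a candidate by current time modulo the list size. A small self-contained sketch under those assumptions, shaped like one party entry of the route_table.json deleted later in this patch:

```java
import java.util.List;
import java.util.Map;

// Sketch of the endpoint-selection rule from the removed DefaultFateRouterServiceImpl.
public class EndpointPickSketch {

    static Map<String, Object> pick(Map<String, List<Map<String, Object>>> partyTable, String role) {
        List<Map<String, Object>> candidates = partyTable.get(role);
        if (candidates == null || candidates.isEmpty()) {
            candidates = partyTable.get("default"); // role not routed explicitly: use the default list
        }
        if (candidates == null || candidates.isEmpty()) {
            return null; // no route known for this party
        }
        // Time-modulo selection, as in the original route() implementation.
        return candidates.get((int) (System.currentTimeMillis() % candidates.size()));
    }

    public static void main(String[] args) {
        // Shape mirrors the "9999" party entry of route_table.json.
        Map<String, List<Map<String, Object>>> party9999 = Map.of(
                "default", List.of(Map.of("ip", "localhost", "port", 9371)),
                "fateflow", List.of(Map.of("ip", "localhost", "port", 9360)));
        System.out.println(pick(party9999, "fateflow")); // selects the fateflow endpoint (port 9360)
    }
}
```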
- */ -package com.osx.broker.server; -import com.osx.broker.ServiceContainer; -import com.osx.broker.grpc.ContextPrepareInterceptor; -import com.osx.broker.grpc.ServiceExceptionHandler; -import com.osx.broker.http.DispatchServlet; -import com.osx.core.config.MetaInfo; -import io.grpc.ServerBuilder; -import io.grpc.ServerInterceptors; -import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; -import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; -import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; -import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; -import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; -import org.apache.commons.lang3.StringUtils; -import org.eclipse.jetty.server.*; -import org.eclipse.jetty.servlet.ServletContextHandler; -import org.eclipse.jetty.servlet.ServletHolder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.net.ssl.SSLException; -import java.io.File; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; - -import static com.osx.core.config.MetaInfo.PROPERTY_OPEN_GRPC_TLS_SERVER; - -/** - * http1.X + grpc - */ -public class OsxServer { - - - Logger logger = LoggerFactory.getLogger(OsxServer.class); - io.grpc.Server server; - io.grpc.Server tlsServer; - org.eclipse.jetty.server.Server httpServer; - - private void init() { - server = buildServer(); - if(MetaInfo.PROPERTY_OPEN_HTTP_SERVER) { - httpServer = buildHttpServer(); - } - - // tlsServer = buildTlsServer(); - } - - public Server buildHttpServer(){ - Server server = new Server(); - try { - int acceptors = 1; - int selectors = 1; - ServerConnector connector = new ServerConnector(server, acceptors, selectors, new HttpConnectionFactory()); - // logger.info("http server try to start listen port {}", MetaInfo.PROPERTY_HTTP_PORT); - connector.setPort(MetaInfo.PROPERTY_HTTP_PORT); - connector.setHost("127.0.0.1"); - connector.setAcceptQueueSize(128); - server.addConnector(connector); - server.setHandler(buildServlet()); - return server; - } catch (Exception e) { - logger.error("build http server error",e); - } - return null; - } - - ServletContextHandler buildServlet(){ - ServletContextHandler context = new ServletContextHandler(); - context.setContextPath(MetaInfo.PROPERTY_HTTP_CONTEXT_PATH); - ServletHolder servletHolder = context.addServlet(DispatchServlet.class, MetaInfo.PROPERTY_HTTP_SERVLET_PATH);//"/*" - return context; - } - - - - public boolean start() { - - init(); - try { - server.start(); - logger.info("listen grpc port {} success", MetaInfo.PROPERTY_GRPC_PORT); - } catch (Exception e) { - if (e instanceof java.net.BindException||e.getCause() instanceof java.net.BindException) { - logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_PORT); - } - return false; - } - try{ - if(httpServer!=null){ - - httpServer.start(); - logger.info("listen http port {} success", MetaInfo.PROPERTY_HTTP_PORT); - } - } - catch (Exception e) { - if (e instanceof java.net.BindException||e.getCause() instanceof java.net.BindException) { - logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_PORT); - } - return false; - } - try{ - if (tlsServer != null) { - logger.info("grpc tls server try to start, listen port {}", MetaInfo.PROPERTY_GRPC_TLS_PORT); - tlsServer.start(); - logger.info("listen grpc tls port {} success", MetaInfo.PROPERTY_GRPC_TLS_PORT); - } - } catch (Exception e) { - if (e instanceof java.net.BindException||e.getCause() 
instanceof java.net.BindException) { - logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_TLS_PORT); - } - return false; - } - return true; - } - - private io.grpc.Server buildTlsServer(){ - String certChainFilePath = MetaInfo.PROPERTY_SERVER_CERTCHAIN_FILE; - String privateKeyFilePath = MetaInfo.PROPERTY_SERVER_PRIVATEKEY_FILE; - String trustCertCollectionFilePath = MetaInfo.PROPERTY_SERVER_CA_FILE; - - if(PROPERTY_OPEN_GRPC_TLS_SERVER && StringUtils.isNotBlank(certChainFilePath) - && StringUtils.isNotBlank(privateKeyFilePath) && StringUtils.isNotBlank(trustCertCollectionFilePath)) { - try { - int port = MetaInfo.PROPERTY_GRPC_TLS_PORT; - NettyServerBuilder serverBuilder = (NettyServerBuilder) ServerBuilder.forPort(port); - SslContextBuilder sslContextBuilder = GrpcSslContexts.forServer(new File(certChainFilePath), new File(privateKeyFilePath)) - .trustManager(new File(trustCertCollectionFilePath)) - .clientAuth(ClientAuth.REQUIRE) - .sessionTimeout(3600 << 4) - .sessionCacheSize(65536); - GrpcSslContexts.configure(sslContextBuilder, SslProvider.OPENSSL); - serverBuilder.sslContext(sslContextBuilder.build()); - logger.info("running in secure mode. server crt path: {}, server key path: {}, ca crt path: {}.", - certChainFilePath, privateKeyFilePath, trustCertCollectionFilePath); - //serverBuilder.executor(executor); - serverBuilder.addService(ServerInterceptors.intercept(ServiceContainer.proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - serverBuilder.addService(ServerInterceptors.intercept(ServiceContainer.pcpGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - return serverBuilder.build(); - } catch (SSLException e) { - throw new SecurityException(e); - } - - - } - return null; - } - - - private io.grpc.Server buildServer() { - NettyServerBuilder nettyServerBuilder = (NettyServerBuilder) ServerBuilder.forPort(MetaInfo.PROPERTY_GRPC_PORT); - nettyServerBuilder.addService(ServerInterceptors.intercept(ServiceContainer.proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - nettyServerBuilder.addService(ServerInterceptors.intercept(ServiceContainer.pcpGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - nettyServerBuilder - .executor(Executors.newCachedThreadPool()) - .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) - .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) - .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) - .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); - - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) - nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) - nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) { - - nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); - } - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) - nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) - 
nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS);
-        if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0)
-            nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS);
-        if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0)
-            nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS);
-        return nettyServerBuilder.build();
-    }
-}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/service/UnaryCallService.java b/java/osx/broker/src/main/java/com/osx/broker/service/UnaryCallService.java
deleted file mode 100644
index 3c650bd18b..0000000000
--- a/java/osx/broker/src/main/java/com/osx/broker/service/UnaryCallService.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.osx.broker.service;
-import com.osx.core.constant.ActionType;
-import com.osx.core.context.Context;
-import com.osx.core.exceptions.ExceptionInfo;
-import com.osx.core.frame.GrpcConnectionFactory;
-import com.osx.core.service.AbstractServiceAdaptor;
-import com.osx.core.service.InboundPackage;
-import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc;
-import com.webank.ai.eggroll.api.networking.proxy.Proxy;
-import io.grpc.Deadline;
-import io.grpc.ManagedChannel;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * For compatibility with legacy FATE versions
- */
-public class UnaryCallService extends AbstractServiceAdaptor {
-
-    Logger logger = LoggerFactory.getLogger(UnaryCallService.class);
-
-
-    public UnaryCallService() {
-
-    }
-
-    @Override
-    protected Proxy.Packet doService(Context context, InboundPackage data) {
-        context.setActionType(ActionType.UNARY_CALL.getAlias());
-        Proxy.Packet req = (Proxy.Packet) data.getBody();
-        Proxy.Packet resp = unaryCall(context, req);
-        //logger.info("uncary req {} resp {}", req, resp);
-        return resp;
-    }
-
-
-    protected Proxy.Packet transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) {
-        return null;
-    }
-
-    /**
-     * Non-streaming transfer
-     *
-     * @param context
-     * @param
-     */
-    public Proxy.Packet unaryCall(Context context, Proxy.Packet req) {
-        Deadline endDeadline = null;
-        boolean isPolling = false;
-
-        ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(),true);
-        DataTransferServiceGrpc.DataTransferServiceBlockingStub stub = DataTransferServiceGrpc.newBlockingStub(managedChannel);
-        Proxy.Packet result = null;
-        result = stub.unaryCall(req);
-        return result;
-    }
-
-
-}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/TransferUtil.java b/java/osx/broker/src/main/java/com/osx/broker/util/TransferUtil.java
deleted file mode 100644
index c757c40f64..0000000000
--- a/java/osx/broker/src/main/java/com/osx/broker/util/TransferUtil.java
+++ /dev/null
@@ -1,295 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.osx.broker.util;
-
-
-import com.google.common.collect.Maps;
-import com.google.protobuf.ByteString;
-import com.google.protobuf.InvalidProtocolBufferException;
-import com.osx.broker.eggroll.ErRollSiteHeader;
-import com.osx.broker.http.HttpClientPool;
-import com.osx.broker.http.PtpHttpResponse;
-import com.osx.broker.queue.TransferQueue;
-import com.osx.core.config.MetaInfo;
-import com.osx.core.constant.Dict;
-import com.osx.core.constant.Protocol;
-import com.osx.core.constant.PtpHttpHeader;
-import com.osx.core.constant.StatusCode;
-import com.osx.core.context.Context;
-import com.osx.core.exceptions.ConfigErrorException;
-import com.osx.core.exceptions.NoRouterInfoException;
-import com.osx.core.exceptions.RemoteRpcException;
-import com.osx.core.frame.GrpcConnectionFactory;
-import com.osx.core.router.RouterInfo;
-import com.webank.ai.eggroll.api.networking.proxy.Proxy;
-import com.webank.eggroll.core.transfer.Transfer;
-import io.grpc.ManagedChannel;
-import io.grpc.StatusRuntimeException;
-import org.apache.commons.lang3.StringUtils;
-import org.ppc.ptp.Osx;
-import org.ppc.ptp.PrivateTransferProtocolGrpc;
-
-import javax.servlet.http.HttpServletRequest;
-import java.util.Map;
-
-public class TransferUtil {
-
-    /**
-     * Versions earlier than 2.0
-     *
-     * @param version
-     * @return
-     */
-    public static boolean isOldVersionFate(String version) {
-
-        try{
-            if (StringUtils.isEmpty(version))
-                version= MetaInfo.PROPERTY_DEFAULT_CLIENT_VERSION;
-            String firstVersion = version.substring(0,1);
-            if (Integer.parseInt(firstVersion) >= 2) {
-                return false;
-            } else {
-                return true;
-            }
-        }catch(NumberFormatException e){
-            throw new ConfigErrorException("remote version config error : "+version);
-        }
-
-    }
-
-
-    public static String buildResource(Osx.Inbound inbound){
-        String sourceNodeId = inbound.getMetadataMap().get(Osx.Header.SourceNodeID.name());
-        String targetNodeId = inbound.getMetadataMap().get(Osx.Header.TargetNodeID.name());
-        String sourceInstId = inbound.getMetadataMap().get(Osx.Header.SourceInstID.name());
-        if(sourceInstId==null){
-            sourceInstId="";
-        }
-        String targetInstId = inbound.getMetadataMap().get(Osx.Header.TargetInstID.name());
-        if(targetInstId==null){
-            targetInstId="";
-        }
-        StringBuffer sb = new StringBuffer();
-        sb.append(sourceInstId).append(sourceNodeId).append("_").append(targetInstId).append(targetNodeId);
-        return sb.toString();
-    }
-
-    public static Proxy.Metadata buildProxyMetadataFromOutbound(Osx.Outbound outbound) {
-        try {
-            return Proxy.Metadata.parseFrom(outbound.getPayload());
-        } catch (InvalidProtocolBufferException e) {
-
-        }
-        return null;
-    }
-    public static Osx.Outbound buildOutboundFromProxyMetadata(Proxy.Metadata metadata) {
-        return Osx.Outbound.newBuilder().setPayload(metadata.toByteString()).build();
-
-    }
-
-    public static Proxy.Packet parsePacketFromInbound(Osx.Inbound inbound){
-        try {
-            return Proxy.Packet.parseFrom(inbound.getPayload());
-        } catch
(InvalidProtocolBufferException e) { - return null; - } - } - - public static Osx.Inbound buildInboundFromPushingPacket(Proxy.Packet packet, String targetMethod) { - Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); - Proxy.Topic srcTopic = packet.getHeader().getSrc(); - String srcPartyId = srcTopic.getPartyId(); - Proxy.Metadata metadata = packet.getHeader(); - ByteString encodedRollSiteHeader = metadata.getExt(); - ErRollSiteHeader rsHeader = null; - try { - rsHeader = ErRollSiteHeader.parseFromPb(Transfer.RollSiteHeader.parseFrom(encodedRollSiteHeader)); - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); - } - - String sessionId = ""; - if (rsHeader != null) { - sessionId = String.join("_", rsHeader.getRollSiteSessionId(), rsHeader.getDstRole(), rsHeader.getDstPartyId()); - } - Proxy.Topic desTopic = packet.getHeader().getDst(); - String desPartyId = desTopic.getPartyId(); - String desRole = desTopic.getRole(); - inboundBuilder.setPayload(packet.toByteString()); - inboundBuilder.putMetadata(Osx.Header.Version.name(), Long.toString(MetaInfo.CURRENT_VERSION)); - inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), MetaInfo.PROPERTY_FATE_TECH_PROVIDER); - inboundBuilder.putMetadata(Osx.Header.Token.name(), ""); - inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), srcPartyId); - inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), desPartyId); - inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), ""); - inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), ""); - inboundBuilder.putMetadata(Osx.Header.SessionID.name(), sessionId); - inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), targetMethod); - inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), desRole); - inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); - return inboundBuilder.build(); - }; - - static public void buildHttpFromPb(Osx.Inbound inbound){ - - - - - } - - - static public Osx.Inbound.Builder buildPbFromHttpRequest(HttpServletRequest request){ - - Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); - String Version = request.getHeader(PtpHttpHeader.Version); - String TechProviderCode = request.getHeader(PtpHttpHeader.TechProviderCode); - String TraceID = request.getHeader(PtpHttpHeader.TraceID); - String Token = request.getHeader(PtpHttpHeader.Token); - String SourceNodeID = request.getHeader(PtpHttpHeader.SourceNodeID); - String TargetNodeID = request.getHeader(PtpHttpHeader.TargetNodeID); - String SourceInstID = request.getHeader(PtpHttpHeader.SourceInstID); - String TargetInstID = request.getHeader(PtpHttpHeader.TargetInstID); - String SessionID = request.getHeader(PtpHttpHeader.SessionID); - String MessageTopic = request.getHeader(PtpHttpHeader.MessageTopic); - String MessageCode = request.getHeader(PtpHttpHeader.MessageCode); - String SourceComponentName = request.getHeader(PtpHttpHeader.SourceComponentName); - String TargetComponentName = request.getHeader(PtpHttpHeader.TargetComponentName); - String TargetMethod = request.getHeader(PtpHttpHeader.TargetMethod); - String MessageOffSet = request.getHeader(PtpHttpHeader.MessageOffSet); - String InstanceId = request.getHeader(PtpHttpHeader.InstanceId); - String Timestamp = request.getHeader(PtpHttpHeader.Timestamp); - - inboundBuilder.putMetadata(Osx.Header.Version.name(), Version != null ? Version : ""); - inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), TechProviderCode != null ? 
TechProviderCode : ""); - inboundBuilder.putMetadata(Osx.Header.Token.name(), Token != null ? Token : ""); - inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), SourceNodeID != null ? SourceNodeID : ""); - inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), TargetNodeID != null ? TargetNodeID : ""); - inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), SourceInstID != null ? SourceInstID : ""); - inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), TargetInstID != null ? TargetInstID : ""); - inboundBuilder.putMetadata(Osx.Header.SessionID.name(), SessionID != null ? SessionID : ""); - inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), TargetMethod != null ? TargetMethod : ""); - inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), TargetComponentName != null ? TargetComponentName : ""); - inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), SourceComponentName != null ? SourceComponentName : ""); - inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), MessageTopic != null ? MessageTopic : ""); - inboundBuilder.putMetadata(Osx.Metadata.MessageOffSet.name(), MessageOffSet != null ? MessageOffSet : ""); - inboundBuilder.putMetadata(Osx.Metadata.InstanceId.name(), InstanceId != null ? InstanceId : ""); - inboundBuilder.putMetadata(Osx.Metadata.Timestamp.name(), Timestamp != null ? Timestamp : ""); - return inboundBuilder; - - - } - - - - static public Osx.Outbound redirect(Context context, Osx.Inbound - produceRequest, RouterInfo routerInfo, boolean forceSend) { - Osx.Outbound result = null; - // context.setActionType("redirect"); - // 目的端协议为grpc - if (routerInfo == null) { - throw new NoRouterInfoException("can not find router info"); - } - if (routerInfo.getProtocol() == null || routerInfo.getProtocol().equals(Protocol.GRPC)) { - ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo,true); - PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); - try { - result = stub.invoke(produceRequest); - } catch (StatusRuntimeException e) { - throw new RemoteRpcException(StatusCode.NET_ERROR, "send to " + routerInfo.toKey() + " error"); - } - // ServiceContainer.tokenApplyService.applyToken(context,routerInfo.getResource(),produceRequest.getSerializedSize()); - }else{ - if(routerInfo.getProtocol().equals(Protocol.HTTP)){ - String url = routerInfo.getUrl(); - - Map metaDataMap = produceRequest.getMetadataMap(); - - String version = metaDataMap.get(Osx.Header.Version.name()); - String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name()); - String traceId = metaDataMap.get(Osx.Header.TraceID.name()); - String token = metaDataMap.get(Osx.Header.Token.name()); - String sourceNodeId = metaDataMap.get(Osx.Header.SourceNodeID.name()); - String targetNodeId = metaDataMap.get(Osx.Header.TargetNodeID.name()); - String sourceInstId = metaDataMap.get(Osx.Header.SourceInstID.name()); - String targetInstId = metaDataMap.get(Osx.Header.TargetInstID.name()); - String sessionId = metaDataMap.get(Osx.Header.SessionID.name()); - String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); - String targetComponentName = metaDataMap.get(Osx.Metadata.TargetComponentName.name()); - String sourceComponentName = metaDataMap.get(Osx.Metadata.SourceComponentName.name()); - String sourcePartyId = StringUtils.isEmpty(sourceInstId) ? sourceNodeId : sourceInstId + "." 
+ sourceNodeId;
-                String targetPartyId = StringUtils.isEmpty(targetInstId) ? targetNodeId : targetInstId + "." + targetNodeId;
-                String topic = metaDataMap.get(Osx.Metadata.MessageTopic.name());
-                String offsetString = metaDataMap.get(Osx.Metadata.MessageOffSet.name());
-                String InstanceId = metaDataMap.get(Osx.Metadata.InstanceId.name());
-                String timestamp = metaDataMap.get(Osx.Metadata.Timestamp.name());
-                String messageCode = metaDataMap.get(Osx.Metadata.MessageCode.name());
-                Map header = Maps.newHashMap();
-                header.put(PtpHttpHeader.Version,version!=null?version:"");
-                header.put(PtpHttpHeader.TechProviderCode,techProviderCode!=null?techProviderCode:"");
-                header.put(PtpHttpHeader.TraceID,traceId!=null?traceId:"");
-                header.put(PtpHttpHeader.Token,token!=null?token:"");
-                header.put(PtpHttpHeader.SourceNodeID,sourceNodeId!=null?sourceNodeId:"");
-                header.put(PtpHttpHeader.TargetNodeID,targetNodeId!=null?targetNodeId:"");
-                header.put(PtpHttpHeader.SourceInstID,sourceInstId!=null?sourceInstId:"");
-                header.put(PtpHttpHeader.TargetInstID,targetInstId!=null?targetInstId:"");
-                header.put(PtpHttpHeader.SessionID,sessionId!=null?sessionId:"");
-                header.put(PtpHttpHeader.MessageTopic,topic!=null?topic:"");
-                header.put(PtpHttpHeader.MessageCode,messageCode);
-                header.put(PtpHttpHeader.SourceComponentName,sourceComponentName!=null?sourceComponentName:"");
-                header.put(PtpHttpHeader.TargetComponentName,targetComponentName!=null?targetComponentName:"");
-                header.put(PtpHttpHeader.TargetMethod,targetMethod!=null?targetMethod:"");
-                header.put(PtpHttpHeader.MessageOffSet,offsetString!=null?offsetString:"");
-                header.put(PtpHttpHeader.InstanceId,InstanceId!=null?InstanceId:"");
-                header.put(PtpHttpHeader.Timestamp,timestamp!=null?timestamp:"");
-                result = HttpClientPool.sendPtpPost(url,produceRequest.getPayload().toByteArray(),header);
-            }
-        }
-
-        return result;
-
-    }
-
-
-    public static Osx.Outbound buildResponse(String code, String msgReturn, TransferQueue.TransferQueueConsumeResult messageWraper) {
-        // FireworkTransfer.ConsumeResponse.Builder consumeResponseBuilder = FireworkTransfer.ConsumeResponse.newBuilder();
-        Osx.Outbound.Builder builder = Osx.Outbound.newBuilder();
-
-        builder.setCode(code);
-        builder.setMessage(msgReturn);
-        if (messageWraper != null) {
-            Osx.Message message = null;
-            try {
-                message = Osx.Message.parseFrom(messageWraper.getMessage().getBody());
-            } catch (InvalidProtocolBufferException e) {
-                e.printStackTrace();
-            }
-            builder.setPayload(message.toByteString());
-            builder.putMetadata(Osx.Metadata.MessageOffSet.name(), Long.toString(messageWraper.getRequestIndex()));
-//            FireworkTransfer.Message msg = produceRequest.getMessage();
-//            consumeResponseBuilder.setTransferId(produceRequest.getTransferId());
-//            consumeResponseBuilder.setMessage(msg);
-//            consumeResponseBuilder.setStartOffset(messageWraper.getRequestIndex());
-//            consumeResponseBuilder.setTotalOffset(messageWraper.getLogicIndexTotal());
-        }
-
-        return builder.build();
-    }
-
-
-    public static void main(String[] args){
-        System.err.println(isOldVersionFate(null));
-    }
-}
diff --git a/java/osx/broker/src/main/resources/broker.properties b/java/osx/broker/src/main/resources/broker.properties
deleted file mode 100644
index c288094306..0000000000
--- a/java/osx/broker/src/main/resources/broker.properties
+++ /dev/null
@@ -1,23 +0,0 @@
-# grpc port
-grpc.port= 9370
-# whether to start the http server
-open.http.server=false
-# http port
-http.port=8080
-# whether to start the grpc+TLS server
-open.grpc.tls.server=false
-# grpc+TLS listen port
-grpc.tls.port=9883
-# local partyId; multiple values are separated by commas
-self.party=10000
-# deploy mode: standalone / cluster; standalone runs a single node, cluster runs multiple nodes
-deploy.model=standalone
-# zookeeper is required in cluster mode; this is the zookeeper address
-zk.url=localhost:2181
-# when eggroll is used as the backend, set the eggroll cluster-manager ip and port here
-eggroll.cluster.manager.ip = localhost
-eggroll.cluster.manager.port = 4670
-
-
-
-
diff --git a/java/osx/broker/src/main/resources/route_table.json b/java/osx/broker/src/main/resources/route_table.json
deleted file mode 100755
index b34487d85d..0000000000
--- a/java/osx/broker/src/main/resources/route_table.json
+++ /dev/null
@@ -1,31 +0,0 @@
-{
-  "route_table":
-  {
-    "9999":
-    {
-      "default":[
-        {
-          "port": 9371,
-          "ip": "localhost"
-        }
-      ],
-      "fateflow":[
-        {
-          "port": 9360,
-          "ip": "localhost"
-        }
-      ]
-    },
-    "10000":{
-      "default":[{
-        "port": 9889,
-        "ip": "localhost"
-      }]
-
-    }
-  },
-  "permission":
-  {
-    "default_allow": true
-  }
-}
diff --git a/java/osx/core/src/main/java/com/osx/core/config/MetaInfo.java b/java/osx/core/src/main/java/com/osx/core/config/MetaInfo.java
deleted file mode 100644
index 1ff1d88779..0000000000
--- a/java/osx/core/src/main/java/com/osx/core/config/MetaInfo.java
+++ /dev/null
@@ -1,230 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
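The keys in the broker.properties removed above correspond to MetaInfo fields declared below (grpc.port maps to PROPERTY_GRPC_PORT, open.http.server to PROPERTY_OPEN_HTTP_SERVER, deploy.model to PROPERTY_DEPLOY_MODE, self.party to PROPERTY_SELF_PARTY, and so on). The actual loading code is outside this diff; a hypothetical reader for those keys might look like the sketch below.

```java
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Properties;

// Hypothetical reader for the deleted broker.properties; the real OSX loading
// code is not part of this patch.
public class BrokerConfigSketch {
    public static void main(String[] args) throws IOException {
        Properties props = new Properties();
        try (FileInputStream in = new FileInputStream("broker.properties")) {
            props.load(in);
        }
        // Defensive trim; Properties.load already drops the space in "grpc.port= 9370".
        int grpcPort = Integer.parseInt(props.getProperty("grpc.port", "9370").trim());
        String deployMode = props.getProperty("deploy.model", "standalone").trim();
        boolean cluster = "cluster".equals(deployMode); // mirrors MetaInfo.isCluster()
        System.out.println("grpc.port=" + grpcPort + ", cluster=" + cluster);
    }
}
```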
- */ - -package com.osx.core.config; - -import com.google.common.collect.Maps; -import com.google.common.collect.Sets; -import com.osx.core.constant.DeployMode; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StreamLimitMode; - -import java.lang.reflect.Field; -import java.util.Map; -import java.util.Set; - -public class MetaInfo { - public static final long CURRENT_VERSION = 100; - public static String PROPERTY_FATE_TECH_PROVIDER = "FATE"; - public static String PROPERTY_DEFAULT_CLIENT_VERSION="2.X.X"; - public static volatile MasterInfo masterInfo; - public static int PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION = 1000; - public static int PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE = 128 << 20; - public static int PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE = (2 << 30) - 1; - public static int PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW = 128 << 20; - public static int PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC = 7200; - public static int PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC = 3600; - public static int PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC = 10; - public static boolean PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED = true; - public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC = 86400; - public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC = 86400; - public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC = 86400; - public static int PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT = 600; - - - - - public static int PROPERTY_GRPC_CLIENT_MAX_CONCURRENT_CALL_PER_CONNECTION = 1000; - public static int PROPERTY_GRPC_CLIENT_MAX_INBOUND_METADATA_SIZE = 128 << 20; - public static int PROPERTY_GRPC_CLIENT_MAX_INBOUND_MESSAGE_SIZE = (2 << 30) - 1; - public static int PROPERTY_GRPC_CLIENT_FLOW_CONTROL_WINDOW = 128 << 20; - public static int PROPERTY_GRPC_CLIENT_KEEPALIVE_TIME_SEC = 7200; - public static int PROPERTY_GRPC_CLIENT_KEEPALIVE_TIMEOUT_SEC = 3600; - public static int PROPERTY_GRPC_CLIENT_PERMIT_KEEPALIVE_TIME_SEC = 10; - public static boolean PROPERTY_GRPC_CLIENT_KEEPALIVE_WITHOUT_CALLS_ENABLED = true; - public static int PROPERTY_GRPC_CLIENT_MAX_CONNECTION_IDLE_SEC = 86400; - public static int PROPERTY_GRPC_CLIENT_MAX_CONNECTION_AGE_SEC = 86400; - public static int PROPERTY_GRPC_CLIENT_MAX_CONNECTION_AGE_GRACE_SEC = 86400; - public static int PROPERTY_GRPC_CLIENT_PER_RPC_BUFFER_LIMIT=86400; - - public static int PROPERTY_GRPC_CLIENT_RETRY_BUFFER_SIZE = 86400; - - - - public static boolean PROPERTY_USE_DIRECT_CACHE = false; - public static int PROPERTY_TRANSFER_FILE_CACHE_SIZE = 1 << 27; - public static int PROPERTY_TRANSFER_RETRY_COUNT = 1; - public static int MAP_FILE_SIZE = 1 << 25; - public static int PROPERTY_INDEX_MAP_FILE_SIZE = 1 << 21; - public static Boolean TRANSFER_FATECLOUD_AHTHENTICATION_ENABLED; - public static Boolean TRANSFER_FATECLOUD_AUTHENTICATION_USE_CONFIG; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_URI; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_APPKEY; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_APPSERCRET; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_ROLE; - public static String TRANSFER_FATECLOUD_SECRET_INFO_URL; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_URL; - public static String PROPERTY_SERVER_CERTCHAIN_FILE; - public static String PROPERTY_SERVER_PRIVATEKEY_FILE; - public static String PROPERTY_SERVER_CA_FILE; - public static int ROLLSITE_PARTY_ID; -// public static Integer PROPERTY_PORT; - public static Integer 
PROPERTY_GRPC_PORT;
-    public static Integer PROPERTY_HTTP_PORT;
-    public static Boolean PROPERTY_OPEN_HTTP_SERVER = false;
-    public static Boolean PROPERTY_OPEN_GRPC_TLS_SERVER = false;
-    public static int PROPERTY_HTTP_REQUEST_BODY_MAX_SIZE=4096;
-    public static String PROPERTY_HTTP_CONTEXT_PATH="/osx";
-    public static String PROPERTY_HTTP_SERVLET_PATH="/*";
-    public static Integer PROPERTY_GRPC_TLS_PORT;
-    public static String PROPERTY_ZK_URL;
-    public static Boolean PROPERTY_USE_DISRUPTOR = true;
-    public static int PROPERTY_STREAM_LIMIT_MAX_TRY_TIME = 3;
-
-    public static String PROPERTY_USER_HOME = "";
-
-    public static Integer PROPERTY_SAMPLE_COUNT = 10;
-    public static Integer PROPERTY_INTERVAL_MS = 1000;
-    //public static Boolean PROPERTY_USE_QUEUE_MODEL = false;
-    public static String PROPERTY_STREAM_LIMIT_MODE = StreamLimitMode.NOLIMIT.name();
-
-    public static Integer PROPERTY_CONSUMER_TIMEOUT = 30000;
-    public static Integer PROPERTY_QUEUE_MAX_FREE_TIME;
-    public static Integer PROPERTY_MAPPED_FILE_EXPIRE_TIME = 3600 * 1000 * 36;
-    public static Integer PROPERTY_MAX_CONSUME_EMPTY_TRY_COUNT = 30;
-
-    public static Integer PROPERTY_MAX_TRANSFER_CACHE_SIZE = 1 << 30;
-    public static String PROPERTY_TRANSFER_FILE_PATH_PRE;
-    public static String PROPERTY_DEPLOY_MODE = "standalone";
-    public static String PROPERTY_TRANSFER_APPLY_CACHE = "/tmp/cachetest";
-
-    public static Set<String> PROPERTY_SELF_PARTY = Sets.newHashSet();//
-
-    public static Integer PROPERTY_APPLY_EXPIRE_TIME = 3000;
-    public static Integer PROPERTY_COORDINATOR;
-    public static Integer PROPERTY_SERVER_PORT;
-    public static String PROPERTY_INFERENCE_SERVICE_NAME;
-    public static String PROPERTY_ROUTE_TYPE;
-    public static String PROPERTY_ROUTE_TABLE;
-
-    public static String PROPERTY_FLOW_RULE_TABLE;
-    public static String PROPERTY_AUTH_FILE;
-    public static Boolean PROPERTY_ACL_ENABLE = false;
-    public static String PROPERTY_ACL_USERNAME;
-    public static String PROPERTY_ACL_PASSWORD;
-    public static String PROPERTY_ROOT_PATH;
-    public static Boolean PROPERTY_PRINT_INPUT_DATA;
-    public static Boolean PROPERTY_PRINT_OUTPUT_DATA;
-
-    public static Boolean PROPERTY_AUTH_OPEN;
-    public static String PROPERTY_NEGOTIATIONTYPE;
-    public static String PROPERTY_PROXY_GRPC_INTER_CA_FILE;
-    public static String PROPERTY_PROXY_GRPC_INTER_CLIENT_CERTCHAIN_FILE;
-    public static String PROPERTY_PROXY_GRPC_INTER_CLIENT_PRIVATEKEY_FILE;
-    public static String PROPERTY_PROXY_GRPC_INTER_SERVER_CERTCHAIN_FILE;
-    public static String PROPERTY_PROXY_GRPC_INTER_SERVER_PRIVATEKEY_FILE;
-    public static Integer PROPERTY_ADMIN_HEALTH_CHECK_TIME;
-    public static Integer PRPPERTY_QUEUE_MAX_FREE_TIME;
-    public static String ROLLSITE_ROUTE_TABLE_KEY;
-    public static String ROLLSITE_ROUTE_TABLE_WHITE_LIST;
-    public static String ROLLSITE_ROUTE_TABLE_PARTY_ID;
-    public static String INSTANCE_ID;
-
-    public static String PROPERTY_EGGROLL_CLUSTER_MANANGER_IP;
-    public static Integer PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT;
-
-
-    public static Integer PROPERTY_CONSUME_SPIN_TIME = 500;
-
-    public static String PROPERTY_CLUSTER_MANAGER_ADDRESS;
-    public static Integer PROPERTY_NETTY_CLIENT_TIMEOUT = 3000;
-
-    public static Integer PROPERTY_HEARTBEAT_INTERVAL = 10000;
-
-    public static String PROPERTY_CLUSTER_MANAGER_HOST;
-    public static Integer PROPERTY_CLUSTER_MANAGER_PORT;
-
-    public static Boolean PROPERTY_USE_ZOOKEEPER = true;
-
-    /**
-     * Timeout for acquiring a connection from the connection pool
-     */
-    public static Integer HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT;
-    /**
-     * Timeout for establishing a connection
-     */
-    public static Integer HTTP_CLIENT_CONFIG_CONN_TIME_OUT;
-    /**
-     * Timeout while waiting for data
-     */
-    public static Integer HTTP_CLIENT_CONFIG_SOCK_TIME_OUT;
-    public static Integer HTTP_CLIENT_INIT_POOL_MAX_TOTAL;
-    public static Integer HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE;
-    public static Integer HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT;
-    public static Integer HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT;
-    public static Integer HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT;
-    public static Integer HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT;
-    public static Integer HTTP_CLIENT_TRAN_CONN_TIME_OUT;
-    public static Integer HTTP_CLIENT_TRAN_SOCK_TIME_OUT;
-
-
-
-
-
-
-    public static String getClusterManagerHost() {
-        if (PROPERTY_CLUSTER_MANAGER_HOST != null) {
-            return PROPERTY_CLUSTER_MANAGER_HOST;
-        } else {
-            PROPERTY_CLUSTER_MANAGER_HOST = PROPERTY_CLUSTER_MANAGER_ADDRESS.split(":")[0];
-            PROPERTY_CLUSTER_MANAGER_PORT = Integer.parseInt(PROPERTY_CLUSTER_MANAGER_ADDRESS.split(":")[1]);
-            return PROPERTY_CLUSTER_MANAGER_HOST;
-        }
-    }
-
-    public static Integer getClusterManagerPort() {
-        if (PROPERTY_CLUSTER_MANAGER_PORT != null) {
-            return PROPERTY_CLUSTER_MANAGER_PORT;
-        } else {
-            PROPERTY_CLUSTER_MANAGER_HOST = PROPERTY_CLUSTER_MANAGER_ADDRESS.split(":")[0];
-            PROPERTY_CLUSTER_MANAGER_PORT = Integer.parseInt(PROPERTY_CLUSTER_MANAGER_ADDRESS.split(":")[1]);
-            return PROPERTY_CLUSTER_MANAGER_PORT;
-        }
-    }
-
-
-    public static boolean isCluster() {
-        return PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name());
-    }
-
-    public static Map<String, Object> toMap() {
-        Map<String, Object> result = Maps.newHashMap();
-        Field[] fields = MetaInfo.class.getFields();
-
-        for (Field field : fields) {
-            try {
-                if (field.get(MetaInfo.class) != null) {
-                    String key = Dict.class.getField(field.getName()) != null ? String.valueOf(Dict.class.getField(field.getName()).get(Dict.class)) : field.getName();
-                    result.put(key, field.get(MetaInfo.class));
-                }
-            } catch (IllegalAccessException | NoSuchFieldException e) {
-
-            }
-        }
-        return result;
-    }
-
-}
diff --git a/java/osx/core/src/main/java/com/osx/core/router/RouterInfo.java b/java/osx/core/src/main/java/com/osx/core/router/RouterInfo.java
deleted file mode 100644
index 0ca0a0b1e3..0000000000
--- a/java/osx/core/src/main/java/com/osx/core/router/RouterInfo.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Copyright 2019 The FATE Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package com.osx.core.router; - -import com.osx.core.constant.Protocol; -import com.osx.core.utils.JsonUtil; -import lombok.Data; - -@Data -public class RouterInfo { - private Protocol protocol; - private String sourcePartyId; - private String desPartyId; - private String desMode; - private String url; - private String host; - private Integer port; - private boolean useSSL = false; - private String negotiationType; - private String certChainFile; - private String privateKeyFile; - private String trustCertCollectionFile; - private String caFile; - private String version; - - public String toKey() { - StringBuffer sb = new StringBuffer(); - sb.append(host).append("_").append(port); - if(negotiationType!=null) - sb.append("_").append(negotiationType); - return sb.toString(); - } - - @Override - public String toString() { - return JsonUtil.object2Json(this); - } - - public String getResource() { - StringBuilder sb = new StringBuilder(); - sb.append(sourcePartyId).append("-").append(desPartyId); - return sb.toString(); - } - - -} \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/utils/FlowLogUtil.java b/java/osx/core/src/main/java/com/osx/core/utils/FlowLogUtil.java deleted file mode 100644 index 241b85544a..0000000000 --- a/java/osx/core/src/main/java/com/osx/core/utils/FlowLogUtil.java +++ /dev/null @@ -1,54 +0,0 @@ -package com.osx.core.utils; - -import com.osx.core.context.Context; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class FlowLogUtil { - static Logger logger = LoggerFactory.getLogger("flow"); - static final String SPLIT= "|"; - public static void printFlowLog(Context context) { - StringBuffer stringBuffer = new StringBuffer(); - if(context.getActionType()!=null){ - stringBuffer.append(context.getActionType()).append(SPLIT); - } - if(context.getSessionId()!=null){ - stringBuffer.append("session:").append(context.getSessionId()).append(SPLIT); - } - if(context.getTopic()!=null){ - stringBuffer.append("topic:").append(context.getTopic()).append(SPLIT); - } - if(context.getRequestMsgIndex()!=null){ - stringBuffer.append("req-offset:").append(context.getRequestMsgIndex()).append(SPLIT); - } - if(context.getCurrentMsgIndex()!=null){ - stringBuffer.append("offset-in-queue:").append(context.getCurrentMsgIndex()).append(SPLIT); - } - if(context.getSrcPartyId()!=null){ - stringBuffer.append("src:").append(context.getSrcPartyId()).append(SPLIT); - } - if(context.getDesPartyId()!=null){ - stringBuffer.append("des:").append(context.getDesPartyId()).append(SPLIT); - } - if(context.getReturnCode()!=null){ - stringBuffer.append("code:").append(context.getReturnCode()).append(SPLIT); - } - stringBuffer.append("cost:").append(System.currentTimeMillis() - context.getTimeStamp()).append(SPLIT); - if(context.getRouterInfo()!=null){ - stringBuffer.append("router_info:").append(context.getRouterInfo().getHost() + ":" + context.getRouterInfo().getPort()).append(SPLIT); - } - if(context.getDataSize()!=null){ - stringBuffer.append("size:").append(context.getDataSize()).append(SPLIT); - } - if(context.getReturnMsg()!=null){ - stringBuffer.append("msg:").append(context.getReturnMsg()); - } - logger.info(stringBuffer.toString()); - - } - - - - - -} diff --git a/java/osx/deploy/auto-package.sh b/java/osx/deploy/auto-package.sh index 02dc1fc4ad..17bc921574 100755 --- a/java/osx/deploy/auto-package.sh +++ b/java/osx/deploy/auto-package.sh @@ -9,21 +9,21 @@ fi mkdir osx/bin mkdir osx/lib mkdir osx/conf +mkdir osx/extension mkdir osx/conf/broker 
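+# new layout: plugin jars are packaged under osx/extension and per-component configs under osx/conf/components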
-#mkdir osx/conf/cluster-manager +mkdir osx/conf/components cd .. mvn clean package -DskipTests - if [[ ! -d "lib" ]]; then mkdir lib fi - -cp -r broker/target/*.jar deploy/osx/lib -cp -r broker/target/lib/* deploy/osx/lib -cp broker/src/main/resources/* deploy/osx/conf/broker +cp -r osx-broker/target/*.jar deploy/osx/lib +cp -r osx-broker/target/lib/* deploy/osx/lib +cp osx-broker/src/main/resources/broker/* deploy/osx/conf/broker +cp -r osx-broker/src/main/resources/components/* deploy/osx/conf/components cp bin/service.sh deploy/osx/ cp bin/common.sh deploy/osx/bin cd deploy diff --git a/java/osx/osx-api/pom.xml b/java/osx/osx-api/pom.xml new file mode 100644 index 0000000000..4dfb72cd89 --- /dev/null +++ b/java/osx/osx-api/pom.xml @@ -0,0 +1,59 @@ + + + + osx + osx + ${osx.version} + + 4.0.0 + + osx-api + + + 8 + 8 + + + + io.grpc + grpc-netty-shaded + + + io.grpc + grpc-protobuf + + + io.grpc + grpc-stub + + + org.eclipse.jetty + jetty-server + + + + org.eclipse.jetty + jetty-servlet + + + + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + + + + + + + + + + \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/constant/Protocol.java b/java/osx/osx-api/src/main/java/com/osx/api/constants/Protocol.java similarity index 92% rename from java/osx/core/src/main/java/com/osx/core/constant/Protocol.java rename to java/osx/osx-api/src/main/java/com/osx/api/constants/Protocol.java index 189b447630..768bfe1f70 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/Protocol.java +++ b/java/osx/osx-api/src/main/java/com/osx/api/constants/Protocol.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.constant; +package com.osx.api.constants; public enum Protocol { - GRPC, HTTP + grpc, + http } diff --git a/java/osx/osx-api/src/main/java/com/osx/api/context/Context.java b/java/osx/osx-api/src/main/java/com/osx/api/context/Context.java new file mode 100644 index 0000000000..34cd98a5d4 --- /dev/null +++ b/java/osx/osx-api/src/main/java/com/osx/api/context/Context.java @@ -0,0 +1,45 @@ +package com.osx.api.context; + +import com.osx.api.constants.Protocol; +import com.osx.api.router.RouterInfo; + +public interface Context { + public String getTechProviderCode() ; + public void setTechProviderCode(String techProviderCode) ; + public String getTraceId() ; + public void setTraceId(String traceId); + public void setJobId(String jobId); + public String getToken() ; + public void setToken(String token) ; + public String getTopic(); + public void setTopic(String topic); + public Protocol getProtocol(); + public void setProtocol(Protocol protocol); + public String getSessionId() ; + public void setSessionId(String sessionId); + public Object getData(Object key); + public void putData(Object key, Object data); + public String getSrcPartyId() ; + public void setSrcPartyId(String guestAppId) ; + public String getDesPartyId() ; + public void setDesPartyId(String hostAppid) ; + public void setSrcComponent(String srcComponent); + public String getSrcComponent(); + public void setDesComponent(String desComponent); + public String getDesComponent(); + public String getReturnCode() ; + public void setReturnCode(String returnCode); + public String getReturnMsg() ; + public void setReturnMsg(String returnMsg); + public String getServiceName(); + public void setServiceName(String serviceName) ; + public String getSelfPartyId(); + public void setSelfPartyId(String 
selfPartyId); + public void setActionType(String actionType); + public String getActionType(); + public void setRouterInfo(RouterInfo routerInfo); + public RouterInfo getRouterInfo(); + public Context subContext(); + + +} diff --git a/java/osx/osx-api/src/main/java/com/osx/api/router/RouterInfo.java b/java/osx/osx-api/src/main/java/com/osx/api/router/RouterInfo.java new file mode 100644 index 0000000000..5d239e100b --- /dev/null +++ b/java/osx/osx-api/src/main/java/com/osx/api/router/RouterInfo.java @@ -0,0 +1,193 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.osx.api.router; +import com.osx.api.constants.Protocol; + + + +public class RouterInfo { + private Protocol protocol; + private String sourcePartyId; + private String desPartyId; + private String desRole; + private String sourceRole; + private String url; + private String host; + private Integer port; + private boolean useSSL = false; + private String negotiationType; + private String certChainFile; + private String privateKeyFile; + private String trustCertCollectionFile; + private String caFile; + private String version; + + public Protocol getProtocol() { + return protocol; + } + + public void setProtocol(Protocol protocol) { + this.protocol = protocol; + } + + public String getSourcePartyId() { + return sourcePartyId; + } + + public void setSourcePartyId(String sourcePartyId) { + this.sourcePartyId = sourcePartyId; + } + + public String getDesPartyId() { + return desPartyId; + } + + public void setDesPartyId(String desPartyId) { + this.desPartyId = desPartyId; + } + + public String getDesRole() { + return desRole; + } + + public void setDesRole(String desRole) { + this.desRole = desRole; + } + + public String getSourceRole() { + return sourceRole; + } + + public void setSourceRole(String sourceRole) { + this.sourceRole = sourceRole; + } + + public String getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public Integer getPort() { + return port; + } + + public void setPort(Integer port) { + this.port = port; + } + + public boolean isUseSSL() { + return useSSL; + } + + public void setUseSSL(boolean useSSL) { + this.useSSL = useSSL; + } + + public String getNegotiationType() { + return negotiationType; + } + + public void setNegotiationType(String negotiationType) { + this.negotiationType = negotiationType; + } + + public String getCertChainFile() { + return certChainFile; + } + + public void setCertChainFile(String certChainFile) { + this.certChainFile = certChainFile; + } + + public String getPrivateKeyFile() { + return privateKeyFile; + } + + public void setPrivateKeyFile(String privateKeyFile) { + this.privateKeyFile = privateKeyFile; + } + + public String getTrustCertCollectionFile() { + return trustCertCollectionFile; + } + + public void setTrustCertCollectionFile(String 
trustCertCollectionFile) { + this.trustCertCollectionFile = trustCertCollectionFile; + } + + public String getCaFile() { + return caFile; + } + + public void setCaFile(String caFile) { + this.caFile = caFile; + } + + public String getVersion() { + return version; + } + + public void setVersion(String version) { + this.version = version; + } + + public boolean isCycle() { + return isCycle; + } + + public void setCycle(boolean cycle) { + isCycle = cycle; + } + + private boolean isCycle; + + public String toKey() { + StringBuffer sb = new StringBuffer(); + if(Protocol.grpc.equals(protocol)) { + sb.append(host).append("_").append(port); + if (negotiationType != null) + sb.append("_").append(negotiationType); + }else { + sb.append(url); + } + return sb.toString(); + } + + @Override + public String toString() { + return toKey(); + } + + public String getResource() { + StringBuilder sb = new StringBuilder(); + sb.append(sourcePartyId).append("-").append(desPartyId); + return sb.toString(); + } + + +} \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/provider/TechProvider.java b/java/osx/osx-api/src/main/java/com/osx/api/tech/provider/TechProvider.java similarity index 84% rename from java/osx/core/src/main/java/com/osx/core/provider/TechProvider.java rename to java/osx/osx-api/src/main/java/com/osx/api/tech/provider/TechProvider.java index 20ad62b8c4..efe2cfc291 100644 --- a/java/osx/core/src/main/java/com/osx/core/provider/TechProvider.java +++ b/java/osx/osx-api/src/main/java/com/osx/api/tech/provider/TechProvider.java @@ -13,24 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.provider; - +package com.osx.api.tech.provider; import io.grpc.stub.StreamObserver; import org.ppc.ptp.Osx; - - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; public interface TechProvider { - + //handles HTTP/1.x requests void processHttpInvoke(HttpServletRequest httpServletRequest,HttpServletResponse httpServletResponse); - + //handles unary (non-streaming) gRPC requests void processGrpcInvoke(Osx.Inbound request, io.grpc.stub.StreamObserver responseObserver); - - String getProviderId(); - - public StreamObserver processGrpcTransport(Osx.Inbound inbound, io.grpc.stub.StreamObserver responseObserver); + //handles streaming gRPC requests + public StreamObserver processGrpcTransport(Osx.Inbound inbound, StreamObserver responseObserver); } diff --git a/java/osx/osx-api/src/main/java/com/osx/api/translator/Translator.java b/java/osx/osx-api/src/main/java/com/osx/api/translator/Translator.java new file mode 100644 index 0000000000..8e01381ad5 --- /dev/null +++ b/java/osx/osx-api/src/main/java/com/osx/api/translator/Translator.java @@ -0,0 +1,16 @@ +package com.osx.api.translator; + + +import com.osx.api.context.Context; +import org.ppc.ptp.Osx; +//translates the data received and sent when different providers communicate with each other +public interface Translator { + //the serving side translates the data it receives + Osx.Inbound translateReceiveInbound(Context context, Osx.Inbound inbound); + //the requesting side translates the returned data it receives + Osx.Outbound translateReceiveOutbound(Context context,Osx.Outbound outbound); + //the requesting side translates the data it sends + Osx.Inbound translateSendInbound(Context context,Osx.Inbound inbound); + //the serving side translates the data it is about to return + Osx.Outbound translateSendOutbound(Context context,Osx.Outbound outbound); +} diff --git a/java/osx/broker/package.xml b/java/osx/osx-broker/package.xml similarity index 73% rename from java/osx/broker/package.xml rename to 
b/java/osx/osx-broker/package.xml @@ -27,7 +27,7 @@ - /lib + /osx target *.jar @@ -37,7 +37,7 @@ - /lib + /osx/lib target/lib *.jar @@ -47,7 +47,7 @@ - / + /osx/ bin service.sh @@ -56,7 +56,7 @@ unix - /bin + /osx/bin bin transfer.sh @@ -65,7 +65,7 @@ unix - /bin + /osx/bin ../bin *.sh @@ -75,22 +75,28 @@ - /conf - src/main/resources + /osx/conf + ../build - transfer.properties - route_table.json + * - /conf - src/main/resources + /osx/conf/broker + ../build/broker - log4j2.xml + *.* + + /osx/conf/components + ../build/components + + *.* + + \ No newline at end of file diff --git a/java/osx/broker/pom.xml b/java/osx/osx-broker/pom.xml similarity index 84% rename from java/osx/broker/pom.xml rename to java/osx/osx-broker/pom.xml index 19e507635d..9092bbf33e 100644 --- a/java/osx/broker/pom.xml +++ b/java/osx/osx-broker/pom.xml @@ -9,27 +9,27 @@ 4.0.0 - broker + osx-broker osx - core + osx-core ${osx.version} - org.eclipse.jetty - jetty-server + org.eclipse.jetty + jetty-server com.google.guava guava - - - - + + com.lmax + disruptor + org.apache.commons commons-lang3 @@ -58,7 +58,10 @@ io.grpc grpc-stub - + + commons-net + commons-net + org.apache.curator curator-recipes @@ -106,18 +109,6 @@ - - org.junit.platform - junit-platform-launcher - 1.0.1 - test - - - org.junit.jupiter - junit-jupiter-engine - 5.0.1 - test - org.junit.vintage junit-vintage-engine @@ -134,13 +125,13 @@ - net.java.dev.jna - jna + net.java.dev.jna + jna - commons-validator - commons-validator + commons-validator + commons-validator @@ -171,5 +162,4 @@ - \ No newline at end of file diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/Bootstrap.java b/java/osx/osx-broker/src/main/java/com/osx/broker/Bootstrap.java new file mode 100644 index 0000000000..ac20c6e7a8 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/Bootstrap.java @@ -0,0 +1,91 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.osx.broker; +import com.osx.core.config.MetaInfo; +import com.osx.core.jvm.JvmInfoCounter; +import com.osx.core.utils.PropertiesUtil; +import com.osx.core.utils.ServerUtil; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.PosixParser; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.Properties; +public class Bootstrap { + static Logger logger = LoggerFactory.getLogger(Bootstrap.class); + static CommandLine commandLine; + static Object lockObject= new Object(); + public static void main(String[] args) { + try { + Options options = ServerUtil.buildCommandlineOptions(new Options()); + commandLine = ServerUtil.parseCmdLine("osx", args, buildCommandlineOptions(options), + new PosixParser()); + String configDir = commandLine.getOptionValue('c'); + logger.info("try to parse config dir {}", configDir); + if (StringUtils.isEmpty(configDir)) { + System.err.println("config file is not set, please use -c to set the config file directory path"); + System.exit(-1); + } + parseConfig(configDir); + Bootstrap bootstrap = new Bootstrap(); + bootstrap.start(args); + Thread shutDownThread = new Thread(bootstrap::stop); + Runtime.getRuntime().addShutdownHook(shutDownThread); + synchronized (lockObject){ + lockObject.wait(); + } + + } catch (Exception ex) { + logger.error("broker start failed", ex); + ex.printStackTrace(); + System.exit(1); + } + } + + private static Options buildCommandlineOptions(final Options options) { + Option opt = new Option("c", "configFile", true, "config properties file"); + opt.setRequired(false); + options.addOption(opt); + return options; + } + + public static void parseConfig(String configDir) { + try { + MetaInfo.PROPERTY_CONFIG_DIR = configDir; + String configFilePath = configDir+ "/broker/broker.properties"; + Properties environment = PropertiesUtil.getProperties(configFilePath); + MetaInfo.init(environment); + } catch (Exception e) { + logger.error("init MetaInfo error", e); + System.exit(1); + } + } + + public void start(String[] args) { + ServiceContainer.init(); + JvmInfoCounter.start(); + } + + public void stop() { + logger.info("try to shutdown server ..."); + if (ServiceContainer.transferQueueManager != null) { + ServiceContainer.transferQueueManager.destroyAll(); + } + } + +} \ No newline at end of file diff --git a/java/osx/broker/src/main/java/com/osx/broker/ServiceContainer.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ServiceContainer.java similarity index 62% rename from java/osx/broker/src/main/java/com/osx/broker/ServiceContainer.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ServiceContainer.java index 0ec60a5397..6f49c3a386 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ServiceContainer.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ServiceContainer.java @@ -17,14 +17,19 @@ import com.osx.broker.consumer.ConsumerManager; +import com.osx.broker.eggroll.EventDriverMsgManager; import com.osx.broker.grpc.PcpGrpcService; import com.osx.broker.grpc.ProxyGrpcService; -import com.osx.broker.interceptor.RequestHandleInterceptor; +import com.osx.broker.http.HttpClientPool; + import com.osx.broker.interceptor.RouterInterceptor; import com.osx.broker.message.AllocateMappedFileService; import com.osx.broker.queue.TransferQueueManager; import com.osx.broker.router.DefaultFateRouterServiceImpl; import 
com.osx.broker.router.FateRouterService; +import com.osx.broker.router.RouterRegister; +import com.osx.broker.security.TokenGeneratorRegister; +import com.osx.broker.security.TokenValidatorRegister; import com.osx.broker.server.OsxServer; import com.osx.broker.service.PushService; import com.osx.broker.service.TokenApplyService; @@ -47,63 +52,95 @@ public class ServiceContainer { static public ConsumerManager consumerManager; - static public PcpGrpcService pcpGrpcService; static public TransferQueueManager transferQueueManager; - static public AllocateMappedFileService allocateMappedFileService; static public FlowCounterManager flowCounterManager; static public OsxServer transferServer; - static public ProxyGrpcService proxyGrpcService; - static public FateRouterService fateRouterService; static public Map serviceAdaptorMap = new HashMap(); static public TokenApplyService tokenApplyService; - static public PushService pushService; - static public UnaryCallService unaryCallService; - static public RequestHandleInterceptor requestHandleInterceptor; - static public MessageStore messageStore; - static public RouterInterceptor routerInterceptor; static public ClusterFlowRuleManager clusterFlowRuleManager; static public DefaultTokenService defaultTokenService; static public CuratorZookeeperClient zkClient; + //tech provider registry static public TechProviderRegister techProviderRegister; + static public EventDriverMsgManager eventDriverMsgManager; + //token validator registry, used for mutual token validation between the two parties + static public TokenValidatorRegister tokenValidatorRegister; + //token generator registry, used for mutual token validation between the two parties + static public TokenGeneratorRegister tokenGeneratorRegister; + + static public RouterRegister routerRegister; + static Logger logger = LoggerFactory.getLogger(ServiceContainer.class); public static void init() { flowCounterManager = createFlowCounterManager(); clusterFlowRuleManager = createClusterFlowRuleManager(); - allocateMappedFileService = createAllocateMappedFileService(); - messageStore = createMessageStore(allocateMappedFileService); zkClient = createCuratorZookeeperClient(); transferQueueManager = createTransferQueueManager(); consumerManager = createTransferQueueConsumerManager(); - fateRouterService = createFateRouterService(); tokenApplyService = createTokenApplyService(); - pushService = createPushService(); - requestHandleInterceptor = createDefaulRequestInterceptor(); - routerInterceptor = createDefaultRouterInterceptor(fateRouterService); - unaryCallService = createUnaryCallService(requestHandleInterceptor,routerInterceptor); - proxyGrpcService = new ProxyGrpcService(pushService, unaryCallService); transferServer = new OsxServer(); defaultTokenService = createDefaultTokenService(); tokenApplyService = createTokenApplyService(); - - - pcpGrpcService = createPcpGrpcService(); + eventDriverMsgManager = createEventDriverMsgManager( consumerManager, transferQueueManager); techProviderRegister = createTechProviderRegister(); + tokenValidatorRegister = createTokenValidatorRegister(); + tokenGeneratorRegister = createTokenGeneratorRegister(); + routerRegister = createRouterRegister(); + HttpClientPool.initPool(); if (!transferServer.start()) { + logger.error("server start failed"); + System.err.println("server start failed"); System.exit(-1); } else { - } - ; + }; + + + } + private static RouterRegister createRouterRegister(){ + RouterRegister routerRegister = new RouterRegister(); + routerRegister.init(); + routerRegister.start(); + return routerRegister; + } + private static TokenValidatorRegister createTokenValidatorRegister(){ 
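+        // same init()/start() lifecycle as the other create*Register factories: init() loads the configured implementations, start() activates them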
TokenValidatorRegister tokenValidatorRegister = new TokenValidatorRegister(); + tokenValidatorRegister.init(); + tokenValidatorRegister.start(); + return tokenValidatorRegister; + } + + private static TokenGeneratorRegister createTokenGeneratorRegister(){ + TokenGeneratorRegister tokenGeneratorRegister = new TokenGeneratorRegister(); + tokenGeneratorRegister.init(); + tokenGeneratorRegister.start(); + return tokenGeneratorRegister; + } + + + private static EventDriverMsgManager createEventDriverMsgManager(ConsumerManager consumerManager,TransferQueueManager transferQueueManager){ + EventDriverMsgManager eventDriverMsgManager = new EventDriverMsgManager(consumerManager,transferQueueManager); + eventDriverMsgManager.init(); + eventDriverMsgManager.start(); + return eventDriverMsgManager; + } + + public static TechProviderRegister createTechProviderRegister() { - TechProviderRegister techProviderRegister = new TechProviderRegister(); - techProviderRegister.init(); - return techProviderRegister; + try { + TechProviderRegister techProviderRegister = new TechProviderRegister(); + techProviderRegister.start(); + return techProviderRegister; + }catch(Exception e){ + logger.error("tech provider create error",e); + } + return null; + } public static PcpGrpcService createPcpGrpcService() { @@ -132,45 +169,12 @@ public static ClusterFlowRuleManager createClusterFlowRuleManager() { return new ClusterFlowRuleManager(); } - public static MessageStore createMessageStore( - AllocateMappedFileService allocateMappedFileService) { - // TransferQueueManager transferQueueManager ,AllocateMappedFileService allocateMappedFileService,String path){ - MessageStore messageStore = new MessageStore(allocateMappedFileService - , MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE + File.separator + MetaInfo.INSTANCE_ID + File.separator + "message-store"); - messageStore.start(); - return messageStore; - - } - - - public static RequestHandleInterceptor createDefaulRequestInterceptor() { - RequestHandleInterceptor requestHandleInterceptor = new RequestHandleInterceptor(); - return requestHandleInterceptor; - } - public static RouterInterceptor createDefaultRouterInterceptor(FateRouterService fateRouterService){ - RouterInterceptor routerInterceptor = new RouterInterceptor(fateRouterService); - return routerInterceptor; - } - - static FlowCounterManager createFlowCounterManager() { FlowCounterManager flowCounterManager = new FlowCounterManager("transfer"); flowCounterManager.startReport(); return flowCounterManager; } - static UnaryCallService createUnaryCallService(RequestHandleInterceptor requestHandleInterceptor,RouterInterceptor routerInterceptor) { - UnaryCallService unaryCallService = new UnaryCallService(); - unaryCallService.addPreProcessor(requestHandleInterceptor); - unaryCallService.addPreProcessor(routerInterceptor); - return unaryCallService; - } - - static PushService createPushService() { - PushService pushService = new PushService(); - return pushService; - } - static ConsumerManager createTransferQueueConsumerManager() { ConsumerManager consumerManager = new ConsumerManager(); return consumerManager; @@ -187,11 +191,7 @@ static TransferQueueManager createTransferQueueManager() { return transferQueueManager; } - static AllocateMappedFileService createAllocateMappedFileService() { - AllocateMappedFileService allocateMappedFileService = new AllocateMappedFileService(); - allocateMappedFileService.start(); - return allocateMappedFileService; - } + } diff --git 
a/java/osx/broker/src/main/java/com/osx/broker/buffer/BufferStatus.java b/java/osx/osx-broker/src/main/java/com/osx/broker/buffer/BufferStatus.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/BufferStatus.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/buffer/BufferStatus.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/ReadResult.java b/java/osx/osx-broker/src/main/java/com/osx/broker/buffer/ReadResult.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/ReadResult.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/buffer/ReadResult.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/ReadStatus.java b/java/osx/osx-broker/src/main/java/com/osx/broker/buffer/ReadStatus.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/ReadStatus.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/buffer/ReadStatus.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/TransferBufferUtil.java b/java/osx/osx-broker/src/main/java/com/osx/broker/buffer/TransferBufferUtil.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/TransferBufferUtil.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/buffer/TransferBufferUtil.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/WriteResult.java b/java/osx/osx-broker/src/main/java/com/osx/broker/buffer/WriteResult.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/WriteResult.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/buffer/WriteResult.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/WriteStatus.java b/java/osx/osx-broker/src/main/java/com/osx/broker/buffer/WriteStatus.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/WriteStatus.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/buffer/WriteStatus.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/callback/CompleteCallback.java b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/CompleteCallback.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/callback/CompleteCallback.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/callback/CompleteCallback.java diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/callback/CreateUserCallback.java b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/CreateUserCallback.java new file mode 100644 index 0000000000..aa7096e7bc --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/CreateUserCallback.java @@ -0,0 +1,37 @@ +package com.osx.broker.callback; + +import com.osx.broker.ServiceContainer; +import com.osx.broker.consumer.GrpcEventHandler; +import com.osx.broker.eggroll.PushEventHandler; +import com.osx.broker.message.Message; +import com.osx.broker.message.MessageExt; +import com.osx.broker.queue.TransferQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class CreateUserCallback implements MsgEventCallback{ + + Logger logger = LoggerFactory.getLogger(CreateUserCallback.class); + public CreateUserCallback(Class eventHandlerClass){ + this.grpcEventHandlerClass = eventHandlerClass; + + } + Class grpcEventHandlerClass ; + + @Override + public void callback(TransferQueue queue , MessageExt message) { + String topic = queue.getTransferId(); + 
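+            // lazily create one event-driven consumer per topic on the first message; later messages reuse it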
if(ServiceContainer.consumerManager.getEventDrivenConsumer(topic)==null){ + GrpcEventHandler grpcEventHandler = null; + try { + grpcEventHandler = (GrpcEventHandler)grpcEventHandlerClass.newInstance(); + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + ServiceContainer.consumerManager.createEventDrivenConsumer(topic,grpcEventHandler); + }; + } + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/callback/DestoryCallback.java b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/DestoryCallback.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/callback/DestoryCallback.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/callback/DestoryCallback.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/callback/ErrorCallback.java b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/ErrorCallback.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/callback/ErrorCallback.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/callback/ErrorCallback.java diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MockDesGrpcEventHandler.java b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MockDesGrpcEventHandler.java new file mode 100644 index 0000000000..9d07963e67 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MockDesGrpcEventHandler.java @@ -0,0 +1,78 @@ +//package com.osx.broker.callback; +// +//import com.google.protobuf.ByteString; +//import com.google.protobuf.InvalidProtocolBufferException; +//import com.osx.broker.ServiceContainer; +//import com.osx.broker.constants.MessageFlag; +//import com.osx.broker.consumer.GrpcEventHandler; +//import com.osx.broker.consumer.MessageEvent; +//import com.osx.broker.message.MessageExt; +//import com.osx.broker.util.TransferUtil; +//import com.osx.core.constant.Dict; +//import com.osx.core.constant.TransferStatus; +//import com.osx.core.frame.GrpcConnectionFactory; +//import com.osx.core.ptp.TargetMethod; +//import com.osx.core.router.RouterInfo; +//import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; +//import io.grpc.ManagedChannel; +//import org.ppc.ptp.Osx; +//import org.ppc.ptp.PrivateTransferProtocolGrpc; +//import org.slf4j.Logger; +//import org.slf4j.LoggerFactory; +// +//import java.nio.charset.StandardCharsets; +// +//public class MockDesGrpcEventHandler extends GrpcEventHandler { +// +// +// +// Logger logger = LoggerFactory.getLogger(MockDesGrpcEventHandler.class); +// +// PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub; +// @Override +// protected void handleMessage(MessageExt message) { +// +// String topic = message.getTopic(); +// String srcPartyId = message.getSrcPartyId(); +// String desPartyId = message .getDesPartyId(); +// try { +// Osx.Inbound inbound = Osx.Inbound.parseFrom(message.getBody()); +// logger.info("receive message topic {} srcPartyId {} desPartyId {} msg {}",topic,srcPartyId,desPartyId,new String(inbound.getPayload().toByteArray())); +// } catch (InvalidProtocolBufferException e) { +// e.printStackTrace(); +// } +// +// } +// +// @Override +// protected void handleError(MessageExt message) { +// logger.info("handle error : {}",new String(message.getBody())); +// } +// +// @Override +// protected void handleComplete(MessageExt message) { +// logger.info("receive complete"); +// +// } +// +// @Override +// 
protected void handleInit(MessageEvent event) { +// +// logger.info("init================= {} {} {} {} {}",topic, backTopic,srcPartyId,desPartyId,sessionId); +// new Thread(new Runnable() { +// @Override +// public void run() { +// for(int i=0;i<10;i++){ +// +// Osx.Outbound outBound = Osx.Outbound.newBuilder().setPayload(ByteString.copyFrom("my name is god".getBytes(StandardCharsets.UTF_8))).build(); +// sendBackMsg(outBound.toByteArray()); +// if(i==9){ +// sendBackCompleted(); +// } +// } +// } +// }).start(); +// } +// +// +//} diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MsgEventCallback.java b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MsgEventCallback.java new file mode 100644 index 0000000000..0ee602bf2f --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MsgEventCallback.java @@ -0,0 +1,10 @@ +package com.osx.broker.callback; + +import com.osx.broker.message.Message; +import com.osx.broker.message.MessageExt; +import com.osx.broker.queue.TransferQueue; + +@FunctionalInterface +public interface MsgEventCallback { + void callback(TransferQueue transferQueue , MessageExt message); +} diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MsgEventDispatchCallback.java b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MsgEventDispatchCallback.java new file mode 100644 index 0000000000..89517c6683 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/callback/MsgEventDispatchCallback.java @@ -0,0 +1,29 @@ +package com.osx.broker.callback; + +import com.osx.broker.ServiceContainer; +import com.osx.broker.consumer.EventDrivenConsumer; +import com.osx.broker.consumer.MessageEvent; +import com.osx.broker.message.Message; +import com.osx.broker.message.MessageExt; +import com.osx.broker.queue.TransferQueue; +import com.osx.core.constant.Dict; + +import javax.xml.ws.Service; + +public class MsgEventDispatchCallback implements MsgEventCallback{ + @Override + public void callback(TransferQueue transferQueue, MessageExt message) { + String topic = transferQueue.getTransferId(); + EventDrivenConsumer eventDrivenConsumer = ServiceContainer.consumerManager.getEventDrivenConsumer(topic); + if(eventDrivenConsumer!=null){ + MessageEvent messageEvent = new MessageEvent(); + messageEvent.setTopic(topic); + messageEvent.setDesComponent(message.getProperty(Dict.DES_COMPONENT)); + messageEvent.setSrcComponent(message.getProperty(Dict.SOURCE_COMPONENT)); + messageEvent.setSrcPartyId(message.getSrcPartyId()); + messageEvent.setDesPartyId(message.getDesPartyId()); + messageEvent.setSessionId(message.getProperty(Dict.SESSION_ID)); + eventDrivenConsumer.fireEvent(messageEvent); + } + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/constants/Direction.java b/java/osx/osx-broker/src/main/java/com/osx/broker/constants/Direction.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/constants/Direction.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/constants/Direction.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/constants/MessageFlag.java b/java/osx/osx-broker/src/main/java/com/osx/broker/constants/MessageFlag.java similarity index 89% rename from java/osx/broker/src/main/java/com/osx/broker/constants/MessageFlag.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/constants/MessageFlag.java index ccd8c7fc52..7be0614e39 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/constants/MessageFlag.java +++ 
b/java/osx/osx-broker/src/main/java/com/osx/broker/constants/MessageFlag.java @@ -17,7 +17,7 @@ public enum MessageFlag { - MSG(0), ERROR(1), COMPELETED(2); + SENDMSG(0), ERROR(1), COMPELETED(2),BACKMSG(3); private int flag; @@ -28,11 +28,13 @@ private MessageFlag(int flag) { static public MessageFlag getMessageFlag(int flag) { switch (flag) { case 0: - return MSG; + return SENDMSG; case 1: return ERROR; case 2: return COMPELETED; + case 3: + return BACKMSG; default: return null; } diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java similarity index 67% rename from java/osx/broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java index 9c8c58231e..4a0261463c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java @@ -14,33 +14,48 @@ * limitations under the License. */ package com.osx.broker.consumer; + import com.google.common.collect.Maps; +import com.lmax.disruptor.EventHandler; +import com.osx.core.frame.Lifecycle; import com.osx.core.frame.ServiceThread; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -public class ConsumerManager { +public class ConsumerManager implements Lifecycle { Logger logger = LoggerFactory.getLogger(ConsumerManager.class); - ScheduledExecutorService scheduledExecutorService = new ScheduledThreadPoolExecutor(1); ConcurrentHashMap unaryConsumerMap = new ConcurrentHashMap<>(); - ConcurrentHashMap streamConsumerMap = new ConcurrentHashMap<>(); - ConcurrentHashMap redirectConsumerMap = new ConcurrentHashMap<>(); + ConcurrentHashMap eventDrivenConsumerMap = new ConcurrentHashMap<>(); AtomicLong consumerIdIndex = new AtomicLong(0); + ServiceThread monitorThread = new ServiceThread() { + @Override + public String getServiceName() { + return "monitor"; + } + + @Override + public void run() { + while (true) { + try { + report(); + } catch (Exception igore) { + } + this.waitForRunning(60000); + } + } + }; ServiceThread longPullingThread = new ServiceThread() { @Override public String getServiceName() { return "longPullingThread"; } - @Override public void run() { int interval = 200; @@ -55,9 +70,10 @@ public void run() { answerCount.addAndGet(unaryConsumer.answerLongPulling()); longPullingWaitingSize.addAndGet(unaryConsumer.getLongPullingQueueSize()); } catch (Exception igore) { - + igore.printStackTrace(); } }); + if (longPullingWaitingSize.get() > 0) { interval = 500; } else { @@ -66,15 +82,16 @@ public void run() { } catch (Exception igore) { } + this.waitForRunning(interval); } } }; - public ConsumerManager() { longPullingThread.start(); + monitorThread.start(); } public Map getUnaryConsumerMap() { @@ -95,6 +112,24 @@ public UnaryConsumer getUnaryConsumer(String transferId) { return unaryConsumerMap.get(transferId); } + public EventDrivenConsumer getEventDrivenConsumer(String topic){ + + return this.eventDrivenConsumerMap.get(topic); + + } + + public EventDrivenConsumer createEventDrivenConsumer(String topic, EventHandler eventHandler){ + logger.info("create 
event driven consumer , {}",topic); + if (eventDrivenConsumerMap.get(topic) == null) { + EventDrivenConsumer eventDrivenConsumer = + new EventDrivenConsumer(consumerIdIndex.get(), topic,eventHandler); + eventDrivenConsumerMap.putIfAbsent(topic, eventDrivenConsumer); + return eventDrivenConsumerMap.get(topic); + } else { + return eventDrivenConsumerMap.get(topic); + } + } + public UnaryConsumer getOrCreateUnaryConsumer(String transferId) { if (unaryConsumerMap.get(transferId) == null) { UnaryConsumer unaryConsumer = @@ -106,50 +141,31 @@ public UnaryConsumer getOrCreateUnaryConsumer(String transferId) { } } - public StreamConsumer getOrCreateStreamConsumer(String transferId) { - - if (streamConsumerMap.get(transferId) == null) { - StreamConsumer streamConsumer = new StreamConsumer(consumerIdIndex.get(), transferId); - streamConsumerMap.putIfAbsent(transferId, streamConsumer); - return streamConsumerMap.get(transferId); - } else { - return streamConsumerMap.get(transferId); - } + public void onComplete(String transferId) { + this.unaryConsumerMap.remove(transferId); + logger.info("remove consumer {}", transferId); } - public synchronized RedirectConsumer getOrCreateRedirectConsumer(String resource) { - logger.info("getOrCreateRedirectConsumer {}", resource); - if (unaryConsumerMap.get(resource) == null) { - RedirectConsumer redirectConsumer = - new RedirectConsumer(consumerIdIndex.get(), resource); - unaryConsumerMap.putIfAbsent(resource, redirectConsumer); - return (RedirectConsumer) unaryConsumerMap.get(resource); - } else { - return (RedirectConsumer) unaryConsumerMap.get(resource); - } + private void checkAndClean() { + } + @Override + public void init() { + + -// public synchronized PushConsumer getOrCreatePushConsumer(String transferId){ -// if (pushConsumerMap.get(transferId) == null) { -// PushConsumer pushConsumer = -// new PushConsumer(consumerIdIndex.get(), transferId); -// pushConsumerMap.putIfAbsent(transferId,pushConsumer); -// return pushConsumerMap.get(transferId); -// } else { -// return pushConsumerMap.get(transferId); -// } -// } - public void onComplete(String transferId) { - this.unaryConsumerMap.remove(transferId); - logger.info("remove consumer {}", transferId); } - /** - * - */ - private void checkAndClean() { + @Override + public void start() { + + } + + @Override + public void destroy() { + } public static class ReportData { diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/EventDrivenConsumer.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/EventDrivenConsumer.java new file mode 100644 index 0000000000..426bd9846f --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/EventDrivenConsumer.java @@ -0,0 +1,52 @@ +package com.osx.broker.consumer; + +import com.lmax.disruptor.BlockingWaitStrategy; +import com.lmax.disruptor.EventHandler; +import com.lmax.disruptor.EventTranslatorOneArg; +import com.lmax.disruptor.dsl.Disruptor; +import com.lmax.disruptor.dsl.ProducerType; +import com.lmax.disruptor.util.DaemonThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class EventDrivenConsumer extends LocalQueueConsumer { + + Logger logger = LoggerFactory.getLogger(EventDrivenConsumer.class); + EventHandler eventHandler; + Disruptor disruptor; + + public EventDrivenConsumer(long consumerId, String topic,EventHandler eventHandler){ + + super(consumerId,topic); + this.eventHandler = eventHandler; + disruptor = new Disruptor(() -> new MessageEvent(), + 2048, DaemonThreadFactory.INSTANCE, 
+ ProducerType.SINGLE, new BlockingWaitStrategy()); + disruptor.handleEventsWith(eventHandler); + disruptor.start(); + logger.info("new EventDrivenConsumer {}",topic); + + } + public static final EventTranslatorOneArg TRANSLATOR = + (event, sequence, arg) -> { + event.setTopic(arg.getTopic()); + event.setDesPartyId(arg.getDesPartyId()); + event.setSrcComponent(arg.getSrcComponent()); + event.setSrcPartyId(arg.getSrcPartyId()); + event.setDesComponent(arg.getDesComponent()); + event.setSessionId(arg.getSessionId()); + }; + + public void fireEvent(MessageEvent event){ + disruptor.publishEvent((EventTranslatorOneArg) TRANSLATOR,event); + } + + public static void main(String[] args){ +// MessageEvent messageEvent = new MessageEvent(); +// EventDrivenConsumer eventDrivenConsumer = new EventDrivenConsumer(0,"test",new MockDesGrpcEventHandler()); +// eventDrivenConsumer.fireEvent(messageEvent); + + } + +} diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/EventDriverRule.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/EventDriverRule.java new file mode 100644 index 0000000000..082e256419 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/EventDriverRule.java @@ -0,0 +1,7 @@ +package com.osx.broker.consumer; + +import com.osx.broker.queue.TransferQueue; +@FunctionalInterface +public interface EventDriverRule { + boolean isMatch(TransferQueue queue); +} diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/GrpcEventHandler.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/GrpcEventHandler.java new file mode 100644 index 0000000000..49a7cfe621 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/GrpcEventHandler.java @@ -0,0 +1,150 @@ +package com.osx.broker.consumer; + +import com.lmax.disruptor.EventHandler; +import com.osx.api.context.Context; +import com.osx.api.router.RouterInfo; +import com.osx.broker.ServiceContainer; +import com.osx.broker.constants.MessageFlag; +import com.osx.broker.message.MessageExt; +import com.osx.broker.queue.TransferQueue; +import com.osx.broker.util.TransferUtil; +import com.osx.core.config.MetaInfo; +import com.osx.core.constant.Dict; +import com.osx.core.constant.StatusCode; +import com.osx.core.constant.TransferStatus; +import com.osx.core.context.FateContext; +import com.osx.core.exceptions.ExceptionInfo; +import com.osx.core.frame.GrpcConnectionFactory; +import com.osx.core.ptp.TargetMethod; +import io.grpc.ManagedChannel; +import org.ppc.ptp.Osx; +import org.ppc.ptp.PrivateTransferProtocolGrpc; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; + +public abstract class GrpcEventHandler implements EventHandler { + + Logger logger = LoggerFactory.getLogger(GrpcEventHandler.class); + public GrpcEventHandler(String provider){ + this.provider = provider; + } + protected TransferStatus transferStatus = TransferStatus.INIT; + protected String provider; + protected String srcPartyId; + protected String desPartyId; + protected String sessionId; + protected String srcComponent; + protected String desComponent; + protected String topic; + protected String backTopic; + protected RouterInfo backRouterInfo; + protected FateContext context; + + public void sendBackException(ExceptionInfo e){ + if(transferStatus==TransferStatus.TRANSFERING) { + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, + TargetMethod.PRODUCE_MSG.name(), backTopic, 
MessageFlag.COMPELETED, sessionId,e.toString().getBytes(StandardCharsets.UTF_8) ); + TransferUtil.redirect(context,inboundBuilder.build(),backRouterInfo,true); + + }else{ + logger.error("unexpected transferStatus {}, expected TRANSFERING", transferStatus); + } + }; + + public void sendBackCompleted(){ + if(transferStatus== TransferStatus.TRANSFERING) { + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, + TargetMethod.PRODUCE_MSG.name(), backTopic, MessageFlag.COMPELETED, sessionId, "completed".getBytes(StandardCharsets.UTF_8)); + TransferUtil.redirect(context,inboundBuilder.build(),backRouterInfo,true); + }else{ + logger.error("unexpected transferStatus {}, expected TRANSFERING", transferStatus); + } + } + + public void sendBackMsg(byte[] data){ + if(transferStatus== TransferStatus.TRANSFERING) { + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, + TargetMethod.PRODUCE_MSG.name(), backTopic, MessageFlag.SENDMSG, sessionId, data); + TransferUtil.redirect(context,inboundBuilder.build(),backRouterInfo,true); + }else{ + logger.error("unexpected transferStatus {}, expected TRANSFERING", transferStatus); + } + } + + protected void init(MessageEvent event){ + + if(transferStatus==TransferStatus.INIT){ + try { + context = new FateContext(); + topic = event.getTopic(); + desComponent = event.getDesComponent(); + srcComponent = event.getSrcComponent(); + srcPartyId = event.getSrcPartyId(); + desPartyId = event.getDesPartyId(); + sessionId = event.getSessionId(); + if (topic.startsWith(Dict.STREAM_SEND_TOPIC_PREFIX)) { + backTopic = topic.replaceAll(Dict.STREAM_SEND_TOPIC_PREFIX, Dict.STREAM_BACK_TOPIC_PREFIX); + } else if (topic.startsWith(Dict.STREAM_BACK_TOPIC_PREFIX)) { + backTopic = topic.replaceAll(Dict.STREAM_BACK_TOPIC_PREFIX, Dict.STREAM_SEND_TOPIC_PREFIX); + } + backRouterInfo = ServiceContainer.routerRegister.getRouterService(MetaInfo.PROPERTY_FATE_TECH_PROVIDER).route(desPartyId,"",srcPartyId,""); + handleInit(event); + transferStatus = TransferStatus.TRANSFERING; + }catch(Throwable e){ + logger.error("grpc event handler init error",e); + transferStatus = TransferStatus.ERROR; + } + } + + + } + + + @Override + public void onEvent(MessageEvent event, long l, boolean b) throws Exception { + + String topic = event.getTopic(); + // logger.info("======event {}",event); + init(event); + if(transferStatus==TransferStatus.TRANSFERING) { + EventDrivenConsumer consumer = ServiceContainer.consumerManager.getEventDrivenConsumer(topic); + TransferQueue.TransferQueueConsumeResult transferQueueConsumeResult = consumer.consume(new FateContext(), -1); + + if (transferQueueConsumeResult.getCode().equals(StatusCode.SUCCESS)) { + long index = transferQueueConsumeResult.getRequestIndex(); + //TODO: the position of this ack needs to be adjusted + consumer.ack(index); + MessageExt messageExt = transferQueueConsumeResult.getMessage(); + + int flag = messageExt.getFlag(); + // logger.info("message flag {}", flag); + switch (flag) { + //msg + case 0: + handleMessage(messageExt); + break; + //error + case 1: + handleError(messageExt); + break; + //completed + case 2: + handleComplete(messageExt); + break; + default: + ; + } + } else { + logger.warn("consume error {}", transferQueueConsumeResult); + } + } + } + + protected abstract void handleMessage(MessageExt message); + protected abstract void handleError(MessageExt message); + protected abstract void handleComplete(MessageExt message); + protected abstract void handleInit(MessageEvent event); + +} diff --git 
a/java/osx/broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java similarity index 96% rename from java/osx/broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java index 4521f4a909..eede9b5c89 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java @@ -16,13 +16,13 @@ package com.osx.broker.consumer; +import com.osx.api.context.Context; import com.osx.broker.ServiceContainer; import com.osx.broker.message.SelectMappedBufferResult; import com.osx.broker.queue.Consumer; import com.osx.broker.queue.TransferQueue; import com.osx.core.constant.StatusCode; import com.osx.core.constant.TransferStatus; -import com.osx.core.context.Context; import com.osx.core.exceptions.AckIndexException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,6 +63,7 @@ public boolean checkMsgIsArrive(long consumeOffset) { TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(transferId); if (transferQueue != null) { long indexFileOffset = transferQueue.getIndexQueue().getLogicOffset().get(); + logger.info("topic {} needs to consume offset {}, {} in queue", transferId, consumeOffset, indexFileOffset); return consumeOffset <= indexFileOffset; } return false; diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/MessageEvent.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/MessageEvent.java new file mode 100644 index 0000000000..adf4efc215 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/MessageEvent.java @@ -0,0 +1,13 @@ +package com.osx.broker.consumer; + +import lombok.Data; + +@Data +public class MessageEvent { + String srcPartyId; + String desPartyId; + String srcComponent; + String desComponent; + String topic; + String sessionId ; +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java index 76e85dcd01..0666f56c89 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java @@ -15,8 +15,9 @@ */ package com.osx.broker.consumer; +import com.osx.api.router.RouterInfo; import com.osx.core.constant.TransferStatus; -import com.osx.core.router.RouterInfo; + import java.util.concurrent.atomic.AtomicBoolean; diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/SourceGrpcEventHandler.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/SourceGrpcEventHandler.java new file mode 100644 index 0000000000..d9a5115218 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/SourceGrpcEventHandler.java @@ -0,0 +1,65 @@ +package com.osx.broker.consumer; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.osx.broker.ServiceContainer; +import com.osx.broker.message.MessageExt; +import com.osx.core.config.MetaInfo; +import com.osx.core.exceptions.ExceptionInfo; +import com.osx.core.utils.JsonUtil; +import io.grpc.stub.StreamObserver; 
+/**
+ * Placed at the stream source; listens for replies coming back from the remote end.
+ */
+public class SourceGrpcEventHandler extends GrpcEventHandler {
+
+    com.google.protobuf.Parser parser;
+    StreamObserver respStreamObserver;
+
+    public SourceGrpcEventHandler(StreamObserver respStreamObserver,
+                                  com.google.protobuf.Parser parser) {
+        super(MetaInfo.PROPERTY_FATE_TECH_PROVIDER);
+        this.parser = parser;
+        this.respStreamObserver = respStreamObserver;
+    }
+
+    @Override
+    protected void handleMessage(MessageExt message) {
+        try {
+            Object data = parser.parseFrom(message.getBody());
+            respStreamObserver.onNext(data);
+        } catch (InvalidProtocolBufferException e) {
+            logger.error("parse reply body error, topic {}", message.getTopic(), e);
+        }
+    }
+
+    @Override
+    protected void handleError(MessageExt message) {
+        try {
+            ExceptionInfo exceptionInfo = JsonUtil.json2Object(message.getBody(), ExceptionInfo.class);
+            respStreamObserver.onError(new Throwable(exceptionInfo.getMessage()));
+        } finally {
+            String topic = message.getTopic();
+            ServiceContainer.transferQueueManager.onCompleted(topic);
+        }
+    }
+
+    @Override
+    protected void handleComplete(MessageExt message) {
+        try {
+            respStreamObserver.onCompleted();
+        } finally {
+            String topic = message.getTopic();
+            ServiceContainer.transferQueueManager.onCompleted(topic);
+        }
+    }
+
+    @Override
+    protected void handleInit(MessageEvent event) {
+
+    }
+}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/StreamConsumer.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/StreamConsumer.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/consumer/StreamConsumer.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/consumer/StreamConsumer.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java
similarity index 74%
rename from java/osx/broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java
index e549831697..31a6bc5c64 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java
@@ -17,9 +17,10 @@
 import com.osx.broker.ServiceContainer;
 import com.osx.broker.queue.TransferQueue;
+import com.osx.broker.util.TransferUtil;
 import com.osx.core.constant.ActionType;
 import com.osx.core.constant.StatusCode;
-import com.osx.core.context.Context;
+import com.osx.core.context.FateContext;
 import com.osx.core.utils.FlowLogUtil;
 import io.grpc.stub.StreamObserver;
 import lombok.Data;
@@ -27,6 +28,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.servlet.http.HttpServletResponse;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ConcurrentLinkedQueue;
@@ -42,11 +44,11 @@ public UnaryConsumer(long consumerId, String transferId) {
         super(consumerId, transferId);
         TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(transferId);
         if (transferQueue != null) {
-            transferQueue.registeDestoryCallback(() -> {
+            transferQueue.registerDestoryCallback(() -> {
                 ServiceContainer.consumerManager.onComplete(transferId);
             });
         }
-        longPullingQueue = new ConcurrentLinkedQueue();
+        longPullingQueue = new ConcurrentLinkedQueue<>();
     }
 
     public int getLongPullingQueueSize() {
@@ -63,6 +65,8 @@ public void addLongPullingQueue(LongPullingHold longPullingHold) {
     }
 
     public synchronized int answerLongPulling() {
+
+
         /*
          * This must be changed to increment only after the ack, otherwise messages can be lost here
          */
@@ -71,11 +75,18 @@ public synchronized int answerLongPulling() {
         List reputList = null;
         while (this.longPullingQueue.size() > 0) {
             try {
+                long indexFileOffset = transferQueue.getIndexQueue().getLogicOffset().get();
                 LongPullingHold longPullingHold = this.longPullingQueue.poll();
-                //StreamObserver streamObserver = longPullingHold.getStreamObserver();
+                long current = System.currentTimeMillis();
                 long needOffset = longPullingHold.getNeedOffset();
-                Context context = longPullingHold.getContext();
+
+                if (longPullingHold.getExpireTimestamp() > 0 && current > longPullingHold.getExpireTimestamp()) {
+                    handleExpire(longPullingHold);
+                    continue;
+                }
+
+                FateContext context = longPullingHold.getContext();
                 context.setActionType(ActionType.LONG_PULLING_ANSWER.getAlias());
                 TransferQueue.TransferQueueConsumeResult consumeResult = null;
                 if (needOffset <= 0) {
@@ -92,7 +103,6 @@
                      * the offset passed in by the client is less than or equal to the index, so it can be consumed
                      */
                     consumeResult = this.consume(context, needOffset);
-
                 }
             }
@@ -101,8 +111,7 @@
                     context.setDataSize(consumeResult.getMessage().getBody().length);
                     Osx.Outbound consumeResponse = buildResponse(StatusCode.SUCCESS, "success", consumeResult);
                     answerCount++;
-                    longPullingHold.getStreamObserver().onNext(consumeResponse);
-                    longPullingHold.getStreamObserver().onCompleted();
+                    longPullingHold.answer(consumeResponse);
                     context.setTopic(transferQueue.getTransferId());
                     context.setReturnCode(StatusCode.SUCCESS);
                     context.setRequestMsgIndex(consumeResult.getRequestIndex());
@@ -115,10 +124,9 @@
                     if (reputList == null)
                         reputList = new ArrayList<>();
                     reputList.add(longPullingHold);
-
                 }
-            } catch (Exception igore) {
-
+            } catch (Exception e) {
+                logger.error("topic {} answer long pulling error", transferId, e);
             }
         }
         if (reputList != null) {
@@ -127,11 +135,30 @@
         return answerCount;
     }
 
+    private void handleExpire(LongPullingHold longPullingHold) {
+        Osx.Outbound consumeResponse = buildResponse(StatusCode.CONSUME_MSG_TIMEOUT, "CONSUME_MSG_TIMEOUT", null);
+        longPullingHold.answer(consumeResponse);
+    }
+
     @Data
     public static class LongPullingHold {
-        Context context;
+        Logger logger = LoggerFactory.getLogger(LongPullingHold.class);
+        FateContext context;
         StreamObserver streamObserver;
+        HttpServletResponse httpServletResponse;
+        long expireTimestamp;
         long needOffset;
+
+        public void answer(Osx.Outbound consumeResponse) {
+            if (streamObserver != null) {
+                streamObserver.onNext(consumeResponse);
+                streamObserver.onCompleted();
+            } else if (httpServletResponse != null) {
+                TransferUtil.writeHttpRespose(httpServletResponse, consumeResponse.getCode(), consumeResponse.getMessage(),
+                        consumeResponse.getPayload() != null ? consumeResponse.getPayload().toByteArray() : null);
+            }
+        }
     }
 }
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/BaseProto.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/BaseProto.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/BaseProto.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/BaseProto.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java
similarity index 97%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java
index 7c0c772d4b..08f140d897 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java
@@ -92,8 +92,8 @@ public ErStore getOrCreateStore(ErStore input) {
         try {
             Meta.Store oriStore = Meta.Store.parseFrom(result.get(0));
             resultErStore = ErStore.parseFromPb(oriStore);
-        } catch (InvalidProtocolBufferException e) {
-            e.printStackTrace();
+        } catch (InvalidProtocolBufferException ignored) {
+
         }
     }
     return resultErStore;
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandClient.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/CommandClient.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandClient.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/CommandClient.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandURI.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/CommandURI.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandURI.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/CommandURI.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErEndpoint.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErEndpoint.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErEndpoint.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErEndpoint.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErFunctor.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErFunctor.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErFunctor.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErFunctor.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErJob.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErJob.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErJob.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErJob.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErPartition.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErPartition.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErPartition.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErPartition.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErProcessor.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErProcessor.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErProcessor.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErProcessor.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErRollSiteHeader.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErRollSiteHeader.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErRollSiteHeader.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErRollSiteHeader.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSession.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErSession.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSession.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErSession.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSessionMeta.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErSessionMeta.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSessionMeta.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErSessionMeta.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStore.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErStore.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStore.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErStore.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStoreLocator.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErStoreLocator.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStoreLocator.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErStoreLocator.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErTask.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErTask.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErTask.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/ErTask.java
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/EventDriverMsgManager.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/EventDriverMsgManager.java
new file mode 100644
index 0000000000..5f2670713f
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/EventDriverMsgManager.java
@@ -0,0 +1,69 @@
+package com.osx.broker.eggroll;
+
+import com.google.common.collect.Lists;
+import com.lmax.disruptor.BlockingWaitStrategy;
+import com.lmax.disruptor.dsl.Disruptor;
+import com.lmax.disruptor.dsl.ProducerType;
+import com.lmax.disruptor.util.DaemonThreadFactory;
+import com.osx.broker.ServiceContainer;
+import com.osx.broker.callback.CreateUserCallback;
+import com.osx.broker.callback.MsgEventCallback;
+import com.osx.broker.callback.MsgEventDispatchCallback;
+import com.osx.broker.consumer.ConsumerManager;
+import com.osx.broker.message.Message;
+import com.osx.broker.queue.TransferQueue;
+import com.osx.broker.queue.TransferQueueManager;
+import com.osx.core.constant.Dict;
+import com.osx.core.frame.Lifecycle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+public class EventDriverMsgManager implements Lifecycle {
+
+    Logger logger = LoggerFactory.getLogger(EventDriverMsgManager.class);
+    ConsumerManager consumerManager = null;
+    TransferQueueManager transferQueueManager = null;
+
+    public EventDriverMsgManager(ConsumerManager consumerManager, TransferQueueManager transferQueueManager) {
+        this.consumerManager = consumerManager;
+        this.transferQueueManager = transferQueueManager;
+    }
+
+    @Override
+    public void init() {
+        MsgEventDispatchCallback dispatchCallback = new MsgEventDispatchCallback();
+        ServiceContainer.transferQueueManager.addMsgCallBackRule(
+                (queue) -> queue.getTransferId().startsWith(Dict.STREAM_SEND_TOPIC_PREFIX),
+                Lists.newArrayList(new CreateUserCallback(PushEventHandler.class), dispatchCallback));
+        ServiceContainer.transferQueueManager.addMsgCallBackRule(
+                (queue) -> queue.getTransferId().startsWith(Dict.STREAM_BACK_TOPIC_PREFIX),
+                Lists.newArrayList(dispatchCallback));
+    }
+
+    @Override
+    public void start() {
+
+    }
+
+    @Override
+    public void destroy() {
+
+    }
+}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/IdUtils.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/IdUtils.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/IdUtils.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/IdUtils.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/MetaCommnads.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/MetaCommnads.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/MetaCommnads.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/MetaCommnads.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PartitionerTypes.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PartitionerTypes.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/PartitionerTypes.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PartitionerTypes.java
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PushEventHandler.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PushEventHandler.java
new file mode 100644
index 0000000000..d111fa8b23
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PushEventHandler.java
@@ -0,0 +1,291 @@
+package com.osx.broker.eggroll;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.osx.api.constants.Protocol;
+import com.osx.api.router.RouterInfo;
+import com.osx.broker.ServiceContainer;
+import com.osx.broker.constants.MessageFlag;
+import com.osx.broker.consumer.GrpcEventHandler;
+import com.osx.broker.consumer.MessageEvent;
+import com.osx.broker.message.MessageExt;
+import com.osx.broker.util.TransferUtil;
+import com.osx.core.config.MetaInfo;
+import com.osx.core.constant.*;
+import com.osx.core.context.FateContext;
+import com.osx.core.exceptions.*;
+import com.osx.core.frame.GrpcConnectionFactory;
+import com.osx.core.ptp.TargetMethod;
+import com.osx.core.utils.ToStringUtils;
+import com.webank.ai.eggroll.api.networking.proxy.Proxy;
+import com.webank.eggroll.core.command.Command;
+import com.webank.eggroll.core.meta.Meta;
+import com.webank.eggroll.core.transfer.Transfer;
+import com.webank.eggroll.core.transfer.TransferServiceGrpc;
+import io.grpc.ManagedChannel;
+import io.grpc.stub.StreamObserver;
+import org.apache.commons.lang3.StringUtils;
+import org.ppc.ptp.Osx;
+import org.ppc.ptp.PrivateTransferProtocolGrpc;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+
+public class PushEventHandler extends GrpcEventHandler {
+    Logger logger = LoggerFactory.getLogger(PushEventHandler.class);
+
+    public PushEventHandler() {
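+        // Note: EventDriverMsgManager above registers this handler, via
+        // CreateUserCallback(PushEventHandler.class), for queues whose topic starts with
+        // Dict.STREAM_SEND_TOPIC_PREFIX; each queued Proxy.Packet is replayed into eggroll
+        // through a put-batch stream, and replies are produced back onto the matching
+        // Dict.STREAM_BACK_TOPIC_PREFIX topic (see buildBackTopic below).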
+        super(MetaInfo.PROPERTY_FATE_TECH_PROVIDER);
+    }
+
+    TransferStatus transferStatus = TransferStatus.INIT;
+    FateContext context = new FateContext();
+    RouterInfo routerInfo;
+    Proxy.Metadata metadata;
+    String brokerTag;
+    ErRollSiteHeader rsHeader = null;
+    CountDownLatch finishLatch;
+    StreamObserver putBatchSinkPushReqSO;
+    String topic = null;
+    String backTopic = null;
+    PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub backBlockingStub;
+    String desRole = null;
+    String srcRole = null;
+    String sessionId = null;
+    RouterInfo revertRouterInfo;
+
+    protected void handleError(MessageExt messageExt) {
+        // TODO: build a more specific exception to propagate upstream
+        try {
+            if (putBatchSinkPushReqSO != null) {
+                putBatchSinkPushReqSO.onError(new Exception());
+            }
+        } finally {
+            String topic = messageExt.getTopic();
+            ServiceContainer.transferQueueManager.onCompleted(topic);
+        }
+    }
+
+    protected void handleComplete(MessageExt messageExt) {
+        try {
+            if (putBatchSinkPushReqSO != null) {
+                putBatchSinkPushReqSO.onCompleted();
+            }
+        } finally {
+            String topic = messageExt.getTopic();
+            ServiceContainer.transferQueueManager.onCompleted(topic);
+        }
+    }
+
+    @Override
+    protected void handleInit(MessageEvent event) {
+
+    }
+
+    protected void handleMessage(MessageExt messageExt) {
+        try {
+            Proxy.Packet packet = null;
+            try {
+                packet = Proxy.Packet.parseFrom(messageExt.getBody());
+            } catch (Exception e) {
+                logger.error("parse packet error {}", new String(messageExt.getBody()), e);
+            }
+            if (transferStatus.equals(TransferStatus.INIT)) {
+                // initialize
+                try {
+                    initEggroll(packet, messageExt);
+                } catch (Exception e) {
+                    logger.error("init eggroll error", e);
+                    transferStatus = TransferStatus.ERROR;
+                }
+            }
+            if (!transferStatus.equals(TransferStatus.TRANSFERING)) {
+                throw new RemoteRpcException("eggroll init error");
+            }
+
+            Transfer.TransferHeader.Builder transferHeaderBuilder = Transfer.TransferHeader.newBuilder();
+            Transfer.TransferHeader tbHeader = transferHeaderBuilder.setId((int) metadata.getSeq())
+                    .setTag(brokerTag)
+                    .setExt(packet.getHeader().getExt()).build();
+            Transfer.TransferBatch.Builder transferBatchBuilder = Transfer.TransferBatch.newBuilder();
+            Transfer.TransferBatch tbBatch = transferBatchBuilder.setHeader(tbHeader)
+                    .setData(packet.getBody().getValue())
+                    .build();
+            putBatchSinkPushReqSO.onNext(tbBatch);
+        } catch (Exception e) {
+            logger.error("handle msg error, topic {}", messageExt.getTopic(), e);
+            if (backBlockingStub != null) {
+                Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider, desPartyId, srcPartyId, TargetMethod.PRODUCE_MSG.name(),
+                        backTopic, MessageFlag.ERROR, sessionId, ErrorMessageUtil.buildRemoteRpcErrorMsg(1343, "handle message error").getBytes());
+                Osx.Outbound outbound = backBlockingStub.invoke(inboundBuilder.build());
+            } else {
+                logger.error("back stub is null");
+            }
+        }
+    }
+
+    private void initEggroll(Proxy.Packet firstRequest, MessageExt messageExt) throws Exception {
+        if (StringUtils.isEmpty(MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_IP)) {
+            throw new SysException("eggroll cluster manager ip is not found");
+        }
+
+        topic = messageExt.getTopic();
+        backTopic = buildBackTopic(topic);
+        metadata = firstRequest.getHeader();
+        ByteString encodedRollSiteHeader = metadata.getExt();
+        rsHeader = ErRollSiteHeader.parseFromPb(Transfer.RollSiteHeader.parseFrom(encodedRollSiteHeader));
+        Integer partitionId = rsHeader.getPartitionId();
+        brokerTag = "putBatch-" + rsHeader.getRsKey("#", "__rsk") + "-" + partitionId;
+        String oneLineStringMetadata = ToStringUtils.toOneLineString(metadata);
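+        // The reverse route (destination party back to source party) is resolved below so that
+        // eggroll results and errors can be redirected to the sender over the back topic.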
+        context.setActionType(ActionType.PUSH_EGGROLL.getAlias());
+        String rsKey = rsHeader.getRsKey("#", "__rsk");
+        sessionId = String.join("_", rsHeader.getRollSiteSessionId(), rsHeader.getDstRole(), rsHeader.getDstPartyId());
+        context.setSessionId(sessionId);
+        desPartyId = metadata.getDst().getPartyId();
+        desRole = metadata.getDst().getRole();
+        srcRole = metadata.getSrc().getRole();
+        srcPartyId = metadata.getSrc().getPartyId();
+        //String srcPartyId, String srcRole, String dstPartyId, String desRole
+        revertRouterInfo = ServiceContainer.routerRegister.getRouterService(MetaInfo.PROPERTY_FATE_TECH_PROVIDER).route(desPartyId, desRole, srcPartyId, srcRole);
+        if (revertRouterInfo == null) {
+            throw new NoRouterInfoException(srcPartyId + " can not found route info");
+        }
+        if (Protocol.grpc.equals(revertRouterInfo.getProtocol())) {
+            ManagedChannel backChannel = GrpcConnectionFactory.createManagedChannel(revertRouterInfo, true);
+            backBlockingStub = PrivateTransferProtocolGrpc.newBlockingStub(backChannel);
+            context.putData(Dict.BLOCKING_STUB, backBlockingStub);
+        }
+
+        ErSession session = null;
+        try {
+            session = PutBatchSinkUtil.sessionCache.get(sessionId);
+        } catch (ExecutionException e) {
+            logger.error("get session error ", e);
+        }
+        if (!SessionStatus.ACTIVE.name().equals(session.getErSessionMeta().getStatus())) {
+            logger.error("eggroll session {} status is {}", sessionId, session.getErSessionMeta().getStatus());
+            IllegalStateException error = new IllegalStateException("eggroll session " + sessionId + " status is " + session.getErSessionMeta().getStatus());
+            throw error;
+        }
+
+        String namespace = rsHeader.getRollSiteSessionId();
+        String name = rsKey;
+        RollPairContext ctx = new RollPairContext(session);
+        Map rpOptions = Maps.newHashMap();
+        rpOptions.putAll(rsHeader.getOptions());
+        rpOptions.put(Dict.TOTAL_PARTITIONS_SNAKECASE, rsHeader.getTotalPartitions().toString());
+
+        if (rsHeader.getDataType().equals("object")) {
+            rpOptions.put(Dict.SERDES, SerdesTypes.EMPTY.name());
+        } else {
+            rpOptions.put(Dict.SERDES, rsHeader.getOptions().getOrDefault("serdes", SerdesTypes.PICKLE.name()));
+        }
+
+        // table creates here
+        RollPair rp = ctx.load(namespace, name, rpOptions);
+        ErPartition partition = rp.getStore().getPartition(partitionId);
+        ErProcessor egg = ctx.getErSession().routeToEgg(partition);
+        String jobId = IdUtils.generateJobId(ctx.getErSession().getSessionId(), brokerTag, "-");
+        Map jobOptions = new HashMap<>();
+
+        jobOptions.putAll(rsHeader.getOptions());
+        jobOptions.put(SessionConfKeys.CONFKEY_SESSION_ID, ctx.getErSession().getSessionId());
+        ErJob job = new ErJob(
+                jobId,
+                RollPair.PUT_BATCH,
+                Lists.newArrayList(rp.getStore()),
+                Lists.newArrayList(rp.getStore()),
+                Lists.newArrayList(),
+                jobOptions);
+
+        ErTask task = new ErTask(brokerTag,
+                RollPair.PUT_BATCH,
+                Lists.newArrayList(partition),
+                Lists.newArrayList(partition),
+                job);
+
+        Future commandFuture = RollPairContext.executor.submit(() -> {
+            CommandClient commandClient = new CommandClient(egg.getCommandEndpoint());
+            Command.CommandResponse commandResponse = commandClient.call(RollPair.EGG_RUN_TASK_COMMAND, task);
+            long begin = System.currentTimeMillis();
+            try {
+                Meta.Task taskMeta = Meta.Task.parseFrom(commandResponse.getResultsList().get(0));
+                ErTask erTask = ErTask.parseFromPb(taskMeta);
+                long now = System.currentTimeMillis();
+                return erTask;
+            } catch (InvalidProtocolBufferException ignored) {
+
+            }
+            return null;
+        });
+        routerInfo = new RouterInfo();
+        context.setRouterInfo(routerInfo);
+        routerInfo.setHost(egg.getTransferEndpoint().getHost());
+        routerInfo.setPort(egg.getTransferEndpoint().getPort());
+        context.setSrcPartyId(routerInfo.getSourcePartyId());
+        context.setDesPartyId(routerInfo.getDesPartyId());
+        ManagedChannel eggChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, false);
+        TransferServiceGrpc.TransferServiceStub stub = TransferServiceGrpc.newStub(eggChannel);
+        StreamObserver eggSiteServicerPushRespSO;
+        putBatchSinkPushReqSO = stub.send(new PutBatchSinkPushRespSO(metadata, commandFuture, new StreamObserver() {
+
+            TransferStatus transferStatus = TransferStatus.INIT;
+
+            private void init() {
+                transferStatus = TransferStatus.TRANSFERING;
+            }
+
+            @Override
+            public void onNext(Proxy.Metadata metadata) {
+                // swap src and des, then look up the route for the reply
+                Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider, desPartyId, srcPartyId, TargetMethod.PRODUCE_MSG.name(),
+                        backTopic, MessageFlag.SENDMSG, sessionId, metadata.toByteString().toByteArray());
+                TransferUtil.redirect(context, inboundBuilder.build(), revertRouterInfo, true);
+            }
+
+            @Override
+            public void onError(Throwable throwable) {
+                ExceptionInfo exceptionInfo = new ExceptionInfo();
+                exceptionInfo.setMessage(throwable.getMessage());
+                String message = throwable.getMessage();
+                Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider, desPartyId, srcPartyId, TargetMethod.PRODUCE_MSG.name(),
+                        backTopic, MessageFlag.SENDMSG, sessionId, exceptionInfo.toString().getBytes(StandardCharsets.UTF_8));
+                TransferUtil.redirect(context, inboundBuilder.build(), revertRouterInfo, true);
+            }
+
+            @Override
+            public void onCompleted() {
+                /* completion callback */
+                try {
+                    Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider, desPartyId, srcPartyId, TargetMethod.PRODUCE_MSG.name(),
+                            backTopic, MessageFlag.COMPELETED, sessionId, "completed".getBytes(StandardCharsets.UTF_8));
+                    Osx.Outbound result = TransferUtil.redirect(context, inboundBuilder.build(), revertRouterInfo, true);
+                } catch (Exception e) {
+                    logger.error("receive completed error", e);
+                }
+            }
+        }, finishLatch));
+        transferStatus = TransferStatus.TRANSFERING;
+    }
+
+    private String buildBackTopic(String oriTopic) {
+        int length = Dict.STREAM_SEND_TOPIC_PREFIX.length();
+        return Dict.STREAM_BACK_TOPIC_PREFIX + oriTopic.substring(length);
+    }
+}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkPushRespSO.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkPushRespSO.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkPushRespSO.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkPushRespSO.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java
similarity index 73%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java
index 5ca1fc72a5..2fbc02d7d7 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java
@@ -40,18 +40,4 @@ public ErSession load(String sessionId) throws Exception {
                 }
             );
-
-//  object PutBatchSinkUtils {
-//    val sessionCache: LoadingCache[String, ErSession] = CacheBuilder.newBuilder
-//      .maximumSize(2000)
-//      .expireAfterWrite(10, TimeUnit.MINUTES)
-//      .concurrencyLevel(100)
-//      .recordStats
-//      .softValues
-//      .build(new CacheLoader[String, ErSession]() {
-//        override def load(key: String): ErSession = {
-//          new ErSession(sessionId = key, createIfNotExists = false)
-//        }
-//      })
-//  }
 }
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPair.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/RollPair.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPair.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/RollPair.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPairContext.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/RollPairContext.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPairContext.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/RollPairContext.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SerdesTypes.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/SerdesTypes.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/SerdesTypes.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/SerdesTypes.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionCommands.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/SessionCommands.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionCommands.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/SessionCommands.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionConfKeys.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/SessionConfKeys.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionConfKeys.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/SessionConfKeys.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionStatus.java b/java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/SessionStatus.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionStatus.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/eggroll/SessionStatus.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/flow/ClusterMetricStatistics.java b/java/osx/osx-broker/src/main/java/com/osx/broker/flow/ClusterMetricStatistics.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/flow/ClusterMetricStatistics.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/flow/ClusterMetricStatistics.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/flow/ClusterRuleUtil.java b/java/osx/osx-broker/src/main/java/com/osx/broker/flow/ClusterRuleUtil.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/flow/ClusterRuleUtil.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/flow/ClusterRuleUtil.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/ContextPrepareInterceptor.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ContextPrepareInterceptor.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/grpc/ContextPrepareInterceptor.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ContextPrepareInterceptor.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/ForwardPullRespSO.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ForwardPullRespSO.java
similarity index 98%
rename from java/osx/broker/src/main/java/com/osx/broker/grpc/ForwardPullRespSO.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ForwardPullRespSO.java
index 97ec8670de..acec8225e2 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/grpc/ForwardPullRespSO.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ForwardPullRespSO.java
@@ -16,9 +16,9 @@
 package com.osx.broker.grpc;
 
 import com.google.common.base.Preconditions;
+import com.osx.api.context.Context;
 import com.osx.broker.constants.Direction;
 import com.osx.broker.util.ResourceUtil;
-import com.osx.core.context.Context;
 import com.webank.ai.eggroll.api.networking.proxy.Proxy;
 import io.grpc.stub.StreamObserver;
 import org.slf4j.Logger;
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/ForwardPushRespSO.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ForwardPushRespSO.java
similarity index 98%
rename from java/osx/broker/src/main/java/com/osx/broker/grpc/ForwardPushRespSO.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ForwardPushRespSO.java
index 037f4f0a13..b544b736d4 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/grpc/ForwardPushRespSO.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ForwardPushRespSO.java
@@ -15,10 +15,10 @@
  */
 package com.osx.broker.grpc;
 
+import com.osx.api.context.Context;
 import com.osx.broker.callback.CompleteCallback;
 import com.osx.broker.callback.ErrorCallback;
 import com.osx.broker.util.TransferUtil;
-import com.osx.core.context.Context;
 import com.webank.ai.eggroll.api.networking.proxy.Proxy;
 import com.webank.eggroll.core.transfer.Transfer;
 import io.grpc.stub.StreamObserver;
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/PcpGrpcService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/PcpGrpcService.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/grpc/PcpGrpcService.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/grpc/PcpGrpcService.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java
similarity index 78%
rename from java/osx/broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java
index 0ebb3f1b0c..47f0d02e31 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java
@@ -15,10 +15,13 @@
  */
 package com.osx.broker.grpc;
 
+import com.osx.broker.interceptor.RouterInterceptor;
+import com.osx.broker.interceptor.UnaryCallHandleInterceptor;
 import com.osx.broker.service.PushService;
 import com.osx.broker.service.UnaryCallService;
 import com.osx.broker.util.ContextUtil;
-import com.osx.core.context.Context;
+import com.osx.api.constants.Protocol;
+import com.osx.core.context.FateContext;
 import com.osx.core.service.InboundPackage;
 import com.osx.core.service.OutboundPackage;
 import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc;
@@ -32,21 +35,22 @@ public class ProxyGrpcService extends DataTransferServiceGrpc.DataTransferServic
     Logger logger = LoggerFactory.getLogger(ProxyGrpcService.class);
     UnaryCallService unaryCallService;
     PushService pushService;
-    public ProxyGrpcService(PushService pushService,
-                            UnaryCallService unaryCallService
+    public ProxyGrpcService(
     ) {
-        this.pushService = pushService;
-        this.unaryCallService = unaryCallService;
+        this.pushService = new PushService();
+        this.unaryCallService = new UnaryCallService();
+        unaryCallService.addPreProcessor(new UnaryCallHandleInterceptor())
+                .addPreProcessor(new RouterInterceptor());
+
     }
 
     public io.grpc.stub.StreamObserver push(
             io.grpc.stub.StreamObserver responseObserver) {
         try {
-            Context context = ContextUtil.buildContext();
-            InboundPackage data = new InboundPackage<>();
-            PushRequestDataWrap pushRequestDataWrap = new PushRequestDataWrap();
-            pushRequestDataWrap.setStreamObserver(responseObserver);
-            data.setBody(pushRequestDataWrap);
+            FateContext context = ContextUtil.buildFateContext(Protocol.grpc);
+            context.setNeedPrintFlowLog(false);
+            InboundPackage data = new InboundPackage<>();
+            data.setBody(responseObserver);
             OutboundPackage outboundPackage = pushService.service(context, data);
             return outboundPackage.getData();
         } catch (Exception e) {
@@ -58,7 +62,7 @@ public io.grpc.stub.StreamObserver
             responseObserver) {
-        Context context = ContextUtil.buildContext();
+        FateContext context = ContextUtil.buildFateContext(Protocol.grpc);
         InboundPackage data = new InboundPackage<>();
         data.setBody(request);
         context.setDataSize(request.getSerializedSize());
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/PullRequestDataWrap.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/PullRequestDataWrap.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/grpc/PullRequestDataWrap.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/grpc/PullRequestDataWrap.java
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java
new file mode 100644
index 0000000000..530f36f096
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java
@@ -0,0 +1,40 @@
+///*
+// * Copyright 2019 The FATE Authors. All Rights Reserved.
+// *
+// * Licensed under the Apache License, Version 2.0 (the "License");
+// * you may not use this file except in compliance with the License.
+// * You may obtain a copy of the License at
+// *
+// *     http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing, software
+// * distributed under the License is distributed on an "AS IS" BASIS,
+// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// * See the License for the specific language governing permissions and
+// * limitations under the License.
+// */
+//package com.osx.broker.grpc;
+//
+//import com.webank.ai.eggroll.api.networking.proxy.Proxy;
+//import io.grpc.stub.StreamObserver;
+//
+//public class PushRequestDataWrap {
+//    Proxy.Packet packet;
+//    StreamObserver streamObserver;
+//
+//    public Proxy.Packet getPacket() {
+//        return packet;
+//    }
+//
+//    public void setPacket(Proxy.Packet packet) {
+//        this.packet = packet;
+//    }
+//
+//    public StreamObserver getStreamObserver() {
+//        return streamObserver;
+//    }
+//
+//    public void setStreamObserver(StreamObserver streamObserver) {
+//        this.streamObserver = streamObserver;
+//    }
+//}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java
similarity index 72%
rename from java/osx/broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java
index 9c84a3587e..4b6d4ce064 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java
@@ -19,19 +19,20 @@
 import com.google.common.collect.Maps;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.InvalidProtocolBufferException;
-import com.osx.broker.ServiceContainer;
+import com.osx.api.constants.Protocol;
+import com.osx.api.context.Context;
+import com.osx.api.router.RouterInfo;
 import com.osx.broker.eggroll.*;
 import com.osx.broker.ptp.PtpForwardPushRespSO;
+import com.osx.broker.router.RouterService;
 import com.osx.broker.util.TransferUtil;
 import com.osx.core.config.MetaInfo;
-import com.osx.core.constant.ActionType;
-import com.osx.core.constant.Dict;
-import com.osx.core.constant.TransferStatus;
-import com.osx.core.context.Context;
+import com.osx.core.constant.*;
+import com.osx.core.context.FateContext;
 import com.osx.core.exceptions.*;
 import com.osx.core.frame.GrpcConnectionFactory;
+import com.osx.core.ptp.SourceMethod;
 import com.osx.core.ptp.TargetMethod;
-import com.osx.core.router.RouterInfo;
 import com.osx.core.utils.FlowLogUtil;
 import com.osx.core.utils.ToStringUtils;
 import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc;
@@ -58,7 +59,7 @@ public class QueuePushReqStreamObserver implements StreamObserver
     static public ConcurrentHashMap queueIdMap = new ConcurrentHashMap<>();
     static AtomicInteger seq = new AtomicInteger(0);
     Logger logger = LoggerFactory.getLogger(QueuePushReqStreamObserver.class);
-    Context context;
+    FateContext context;
     ErRollSiteHeader rsHeader = null;
     TransferStatus transferStatus = TransferStatus.INIT;
     CountDownLatch finishLatch = new CountDownLatch(1);
@@ -73,16 +74,20 @@ public class QueuePushReqStreamObserver implements StreamObserver
     private Class backRespSOClass;
     private String transferId;
     private Integer queueId;
+    private RouterService routerService;
 
-    public QueuePushReqStreamObserver(Context context, StreamObserver backRespSO,
+    public QueuePushReqStreamObserver(Context context, RouterService routerService, StreamObserver backRespSO,
                                       Class backRespSOClass
     ) {
+        this.context = (FateContext) context;
+        this.routerService = routerService;
         this.backRespSOClass = backRespSOClass;
        this.backRespSO = backRespSO;
-        this.context = context.subContext();
-        this.context.setNeedPrintFlowLog(true);
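+        // The RouterService is injected by the caller; init(...) resolves the destination route
+        // from the context assembled out of the packet header, replacing the former global
+        // ServiceContainer.fateRouterService lookup.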
         this.context.setServiceName("pushTransfer");
+
     }
 
     public StreamObserver getForwardPushReqSO() {
@@ -95,12 +100,12 @@ public void setForwardPushReqSO(StreamObserver forwardPushReqSO) {
 
     public void init(Proxy.Packet packet) throws Exception {
+        TransferUtil.assableContextFromProxyPacket(context, packet);
         Proxy.Metadata metadata = packet.getHeader();
-        String desPartyId = metadata.getDst().getPartyId();
-        String srcPartyId = metadata.getSrc().getPartyId();
+        String desPartyId = context.getDesPartyId();
+        String srcPartyId = context.getSrcPartyId();
         ByteString encodedRollSiteHeader = metadata.getExt();
         rsHeader = ErRollSiteHeader.parseFromPb(Transfer.RollSiteHeader.parseFrom(encodedRollSiteHeader));
-
         Integer partitionId = rsHeader.getPartitionId();
         brokerTag = "putBatch-" + rsHeader.getRsKey("#", "__rsk") + "-" + partitionId;
         context.setSessionId(rsHeader.getRollSiteSessionId());
@@ -113,10 +118,9 @@ public void init(Proxy.Packet packet) throws Exception {
          * Check whether the destination is this party itself
          */
         if (!isDst) {
-            routerInfo = ServiceContainer.fateRouterService.route(packet);
+            routerInfo = routerService.route(context.getSrcPartyId(), context.getSrcComponent(), context.getDesPartyId(), context.getDesComponent());
             if (routerInfo != null) {
                 this.transferId = routerInfo.getResource();
-
             } else {
                 throw new NoRouterInfoException("no router");
             }
@@ -129,48 +133,51 @@ public void init(Proxy.Packet packet) throws Exception {
             context.setRouterInfo(routerInfo);
             context.setSrcPartyId(routerInfo.getSourcePartyId());
             context.setDesPartyId(routerInfo.getDesPartyId());
-            ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(),true);
-            if (TransferUtil.isOldVersionFate(routerInfo.getVersion())) {
-                DataTransferServiceGrpc.DataTransferServiceStub stub = DataTransferServiceGrpc.newStub(managedChannel);
-                ForwardPushRespSO forwardPushRespSO = new ForwardPushRespSO(context, backRespSO,backRespSOClass, () -> {
-                    finishLatch.countDown();
-                }, (t) -> {
-                    finishLatch.countDown();
-                });
-                forwardPushReqSO = stub.push(forwardPushRespSO);
-            } else {
-                PtpForwardPushRespSO ptpForwardPushRespSO = new PtpForwardPushRespSO(context, backRespSO, backRespSOClass, () -> {
-                    finishLatch.countDown();
-                }, (t) -> {
-                    finishLatch.countDown();
-                });
-                PrivateTransferProtocolGrpc.PrivateTransferProtocolStub stub = PrivateTransferProtocolGrpc.newStub(managedChannel);
-
-                StreamObserver ptpForwardPushReqSO = stub.transport(ptpForwardPushRespSO);
-
-                forwardPushReqSO = new StreamObserver() {
-                    @Override
-                    public void onNext(Proxy.Packet packet) {
-                        Osx.Inbound inbound = TransferUtil.buildInboundFromPushingPacket(packet, TargetMethod.PUSH.name());
-                        ptpForwardPushReqSO.onNext(inbound);
-                    }
-                    @Override
-                    public void onError(Throwable throwable) {
-                        ptpForwardPushReqSO.onError(throwable);
-                    }
-
-                    @Override
-                    public void onCompleted() {
-                        ptpForwardPushReqSO.onCompleted();
-                    }
-                };
+            if (routerInfo.getProtocol().equals(Protocol.http)) {
+                // For a transfer initiated by this side that replaces streaming with queues,
+                // a local queue must be created to receive the answers.
+                forwardPushReqSO = QueueStreamBuilder.createStreamFromOrigin(context, backRespSO, Proxy.Packet.parser(),
+                        routerInfo, srcPartyId, desPartyId, rsHeader.getRollSiteSessionId(), finishLatch);
+            } else {
+                ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(), true);
+                if (TransferUtil.isOldVersionFate(routerInfo.getVersion())) {
+                    DataTransferServiceGrpc.DataTransferServiceStub stub = DataTransferServiceGrpc.newStub(managedChannel);
+                    ForwardPushRespSO forwardPushRespSO = new ForwardPushRespSO(context, backRespSO,
+                            backRespSOClass, () -> {
+                        finishLatch.countDown();
+                    }, (t) -> {
+                        finishLatch.countDown();
+                    });
+                    forwardPushReqSO = stub.push(forwardPushRespSO);
+                } else {
+                    PtpForwardPushRespSO ptpForwardPushRespSO = new PtpForwardPushRespSO(context, backRespSO, backRespSOClass, () -> {
+                        finishLatch.countDown();
+                    }, (t) -> {
+                        finishLatch.countDown();
+                    });
+                    PrivateTransferProtocolGrpc.PrivateTransferProtocolStub stub = PrivateTransferProtocolGrpc.newStub(managedChannel);
+                    StreamObserver ptpForwardPushReqSO = stub.transport(ptpForwardPushRespSO);
+                    forwardPushReqSO = new StreamObserver() {
+                        @Override
+                        public void onNext(Proxy.Packet packet) {
+                            Osx.Inbound inbound = TransferUtil.buildInboundFromPushingPacket(packet, MetaInfo.PROPERTY_FATE_TECH_PROVIDER, TargetMethod.PUSH.name(), SourceMethod.PUSH.name()).build();
+                            ptpForwardPushReqSO.onNext(inbound);
+                        }
+
+                        @Override
+                        public void onError(Throwable throwable) {
+                            ptpForwardPushReqSO.onError(throwable);
+                        }
+
+                        @Override
+                        public void onCompleted() {
+                            ptpForwardPushReqSO.onCompleted();
+                        }
+                    };
+                }
             }
         }
         transferStatus = TransferStatus.TRANSFERING;
-
-
     }
 
     private void initEggroll(Proxy.Packet firstRequest) {
@@ -191,7 +198,7 @@ private void initEggroll(Proxy.Packet firstRequest) {
             logger.error("get session error ", e);
         }
         if (!SessionStatus.ACTIVE.name().equals(session.getErSessionMeta().getStatus())) {
-            IllegalStateException error = new IllegalStateException("session=${sessionId} with illegal status. expected=${SessionStatus.ACTIVE}, actual=${session.sessionMeta.status}");
+            SessionInitException error = new SessionInitException("eggroll session " + sessionId + " invalid status: " + session.getErSessionMeta().getStatus());
             onError(error);
             throw error;
         }
@@ -242,8 +249,8 @@ private void initEggroll(Proxy.Packet firstRequest) {
             ErTask erTask = ErTask.parseFromPb(taskMeta);
             long now = System.currentTimeMillis();
             return erTask;
-        } catch (InvalidProtocolBufferException e) {
-            e.printStackTrace();
+        } catch (InvalidProtocolBufferException ignored) {
+
         }
         return null;
     });
@@ -310,28 +317,11 @@ public void onError(Throwable t) {
          * 2. destroy the queue
          */
         if (isDst) {
-            //transferQueue.onError(t);
-
-
             putBatchSinkPushReqSO.onError(t);
         } else {
-
-//            if(MetaInfo.PROPERTY_USE_QUEUE_MODEL){
-//                if(transferQueue!=null){
-//                    AbstractServiceAdaptor.ExceptionInfo exceptionInfo = new AbstractServiceAdaptor.ExceptionInfo();
-//                    exceptionInfo.setMessage(t.getMessage());
-//                    exceptionInfo.setThrowable(t);
-//                    MessageExtBrokerInner messageExtBrokerInner = MessageDecoder.buildMessageExtBrokerInner(transferId,exceptionInfo.toString().getBytes(StandardCharsets.UTF_8),
-//                            queueId,MessageFlag.ERROR,routerInfo.getSourcePartyId(),routerInfo.getDesPartyId());
-//                    transferQueue.putMessage(messageExtBrokerInner);
-//                }
-//            }else
-
-            {
                 if (forwardPushReqSO != null) {
                     forwardPushReqSO.onError(t);
                 }
-            }
     }
 
@@ -342,26 +332,11 @@ public void onCompleted() {
         logger.info("transferId {} receive completed", transferId);
         if (isDst) {
-//            if(transferQueue!=null) {
-//                transferQueue.setWriteOver(true);
-//            }
             if (putBatchSinkPushReqSO != null) {
                 putBatchSinkPushReqSO.onCompleted();
             }
         } else {
-            if (forwardPushReqSO != null) {
-
-//                if(MetaInfo.PROPERTY_USE_QUEUE_MODEL){
-//                    /**
-//                     * Notified by the pushConsumer: to preserve ordering and make sure earlier data
-//                     * has been delivered, this can only run serially at the tail of the queue.
-//                     */
-//                    MessageExtBrokerInner messageExtBrokerInner = MessageDecoder.buildMessageExtBrokerInner(transferId,null,queueId,MessageFlag.COMPELETED,
-//                            routerInfo.getSourcePartyId(),routerInfo.getDesPartyId());
-//                    PutMessageResult putMessageResult =
-//                            transferQueue.putMessage(messageExtBrokerInner);
-//                }else
-
-            {
                 forwardPushReqSO.onCompleted();
                 try {
                     if (!finishLatch.await(MetaInfo.PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT, TimeUnit.SECONDS)) {
@@ -373,12 +348,7 @@ public void onCompleted() {
                         needPrintFlow = false;
                     }
                 }
-            }
-//        if(needPrintFlow){
-//            context.setActionType("push");
-//            context.printFlowLog();
-//        }
-        logger.info("receive completed !!!!");
+            }
     }
 }
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/QueueStreamBuilder.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/QueueStreamBuilder.java
new file mode 100644
index 0000000000..140423988c
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/QueueStreamBuilder.java
@@ -0,0 +1,125 @@
+package com.osx.broker.grpc;
+
+import com.google.protobuf.*;
+import com.osx.api.router.RouterInfo;
+import com.osx.broker.ServiceContainer;
+import com.osx.broker.constants.MessageFlag;
+import com.osx.broker.consumer.SourceGrpcEventHandler;
+import com.osx.broker.eggroll.PushEventHandler;
+import com.osx.broker.queue.CreateQueueResult;
+import com.osx.broker.queue.TransferQueue;
+import com.osx.broker.util.TransferUtil;
+import com.osx.core.config.MetaInfo;
+import com.osx.core.constant.ActionType;
+import com.osx.core.constant.Dict;
+import com.osx.core.constant.StatusCode;
+import com.osx.core.context.FateContext;
+import com.osx.core.exceptions.ExceptionInfo;
+import com.osx.core.exceptions.RemoteRpcException;
+import com.osx.core.frame.GrpcConnectionFactory;
+import com.osx.core.ptp.TargetMethod;
+import com.osx.core.utils.JsonUtil;
+import com.webank.ai.eggroll.api.networking.proxy.Proxy;
+import io.grpc.ManagedChannel;
+import io.grpc.stub.StreamObserver;
+import org.ppc.ptp.Osx;
+import org.ppc.ptp.PrivateTransferProtocolGrpc;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Random;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class QueueStreamBuilder {
+
+    ConcurrentHashMap backRegister = new ConcurrentHashMap();
+
+    private static AtomicInteger count = new AtomicInteger(0);
+
+    private static Logger logger = LoggerFactory.getLogger(QueueStreamBuilder.class);
+
+    /**
+     * Called at the origin of a stream.
+     * @param respStreamObserver
+     * @param parser
+     * @param srcPartyId
+     * @param desPartyId
+     * @param sessionId
+     * @return
+     */
+    public static StreamObserver createStreamFromOrigin(FateContext context,
+                                                        StreamObserver respStreamObserver,
+                                                        Parser parser,
+                                                        RouterInfo routerInfo,
+                                                        String srcPartyId,
+                                                        String desPartyId,
+                                                        String sessionId,
+                                                        CountDownLatch countDownLatch
+    ) {
+        int temp = count.addAndGet(1);
+        long now = System.currentTimeMillis();
+        String backTopic = Dict.STREAM_BACK_TOPIC_PREFIX + "_" + now + "_" + sessionId + "_" + temp;
+        String sendTopic = Dict.STREAM_SEND_TOPIC_PREFIX + "_" + now + "_" + sessionId + "_" + temp;
+        context.setTopic(sendTopic);
+        context.setActionType(ActionType.MSG_REDIRECT.getAlias());
+        CreateQueueResult createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(backTopic, sessionId, true);
+        if (createQueueResult.getTransferQueue() == null) {
+            throw new RemoteRpcException("create queue error");
+        }
+        TransferQueue answerQueue = createQueueResult.getTransferQueue();
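+        // Wire the answer queue to an event-driven consumer: replies arriving on backTopic are
+        // parsed and pushed to the caller's StreamObserver by SourceGrpcEventHandler, while the
+        // observer returned below produces this side's outgoing messages onto sendTopic.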
+        ServiceContainer.consumerManager.createEventDrivenConsumer(backTopic, new SourceGrpcEventHandler(respStreamObserver, parser));
+        StreamObserver forwardPushReqSO = new StreamObserver() {
+
+            @Override
+            public void onNext(AbstractMessage message) {
+                context.setMessageFlag(MessageFlag.SENDMSG.name());
+                Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(MetaInfo.PROPERTY_FATE_TECH_PROVIDER, srcPartyId, desPartyId, TargetMethod.PRODUCE_MSG.name(),
+                        sendTopic, MessageFlag.SENDMSG, sessionId, message.toByteArray());
+                Osx.Outbound outbound = TransferUtil.redirect(context, inboundBuilder.build(), routerInfo, true);
+                TransferUtil.checkResponse(outbound);
+            }
+
+            @Override
+            public void onError(Throwable throwable) {
+                context.setMessageFlag(MessageFlag.ERROR.name());
+                ExceptionInfo exceptionInfo = new ExceptionInfo();
+                exceptionInfo.setMessage(throwable.getMessage());
+                String errorData = JsonUtil.object2Json(exceptionInfo);
+                Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(MetaInfo.PROPERTY_FATE_TECH_PROVIDER, srcPartyId, desPartyId, TargetMethod.PRODUCE_MSG.name(),
+                        sendTopic, MessageFlag.ERROR, sessionId, errorData.getBytes(StandardCharsets.UTF_8))
+                        .putMetadata(Osx.Metadata.MessageFlag.name(), MessageFlag.ERROR.name());
+                Osx.Outbound outbound = TransferUtil.redirect(context, inboundBuilder.build(), routerInfo, true);
+                TransferUtil.checkResponse(outbound);
+                countDownLatch.countDown();
+            }
+
+            @Override
+            public void onCompleted() {
+                context.setMessageFlag(MessageFlag.COMPELETED.name());
+                Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(MetaInfo.PROPERTY_FATE_TECH_PROVIDER, srcPartyId, desPartyId, TargetMethod.PRODUCE_MSG.name(),
+                        sendTopic, MessageFlag.COMPELETED, sessionId, "completed".getBytes(StandardCharsets.UTF_8))
+                        .putMetadata(Osx.Metadata.MessageFlag.name(), MessageFlag.COMPELETED.name());
+                Osx.Outbound outbound = TransferUtil.redirect(context, inboundBuilder.build(), routerInfo, true);
+                TransferUtil.checkResponse(outbound);
+                countDownLatch.countDown();
+            }
+        };
+        return forwardPushReqSO;
+    }
+}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/ServiceExceptionHandler.java b/java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ServiceExceptionHandler.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/grpc/ServiceExceptionHandler.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/grpc/ServiceExceptionHandler.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/http/DispatchServlet.java b/java/osx/osx-broker/src/main/java/com/osx/broker/http/DispatchServlet.java
similarity index 69%
rename from java/osx/broker/src/main/java/com/osx/broker/http/DispatchServlet.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/http/DispatchServlet.java
index 2e42728670..cba8b644a0 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/http/DispatchServlet.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/http/DispatchServlet.java
@@ -18,9 +18,7 @@
 import com.osx.broker.ServiceContainer;
 import com.osx.core.constant.PtpHttpHeader;
 import com.osx.core.provider.TechProvider;
-import com.osx.tech.provider.TechProviderRegister;
 import org.apache.commons.lang3.StringUtils;
-import org.eclipse.jetty.http.HttpHeader;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -33,52 +31,49 @@ public class DispatchServlet extends HttpServlet {
 
     Logger logger = LoggerFactory.getLogger(DispatchServlet.class);
 
+    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
+        // handle GET request
         String protocol = req.getProtocol();
         if (!protocol.endsWith("1.1")) {
             resp.sendError(405, "http.method_get_not_supported");
         }
-        String techProviderCode =req.getHeader(PtpHttpHeader.TechProviderCode);
-        if(StringUtils.isNotEmpty(techProviderCode)){
+        String techProviderCode = req.getHeader(PtpHttpHeader.TechProviderCode);
+        if (StringUtils.isNotEmpty(techProviderCode)) {
             TechProvider techProvider = ServiceContainer.techProviderRegister.select(techProviderCode);
-            if(techProvider!=null) {
+            if (techProvider != null) {
                 techProvider.processHttpInvoke(req, resp);
-            }else{
-                resp.sendError(404,"tech-provider-code invalid");
+            } else {
+                resp.sendError(404, "tech-provider-code invalid");
             }
-        }else{
-            resp.sendError(404,"tech-provider-code invalid");
+        } else {
+            resp.sendError(404, "tech-provider-code invalid");
         }
-        String requestUri =req.getRequestURI();
-        logger.info("receive request uri {}",requestUri);
+        String requestUri = req.getRequestURI();
+        logger.info("receive request uri {}", requestUri);
     }
 
     protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
-        String requestUri =req.getRequestURI();
-        logger.info("receive request uri {}",requestUri);
+        // handle POST request
+        String requestUri = req.getRequestURI();
+        //logger.info("receive request uri {}", requestUri);
         String protocol = req.getProtocol();
         if (!protocol.endsWith("1.1")) {
             resp.sendError(405, "http.method_get_not_supported");
         }
-        String techProviderCode =req.getHeader(PtpHttpHeader.TechProviderCode);
-        if(StringUtils.isNotEmpty(techProviderCode)){
+        String techProviderCode = req.getHeader(PtpHttpHeader.TechProviderCode);
+        if (StringUtils.isNotEmpty(techProviderCode)) {
             TechProvider techProvider = ServiceContainer.techProviderRegister.select(techProviderCode);
-            if(techProvider!=null) {
+            if (techProvider != null) {
                 techProvider.processHttpInvoke(req, resp);
-            }else{
-                resp.sendError(404,"tech-provider-code invalid");
+            } else {
+                resp.sendError(404, "tech-provider-code invalid");
             }
-        }else{
-            resp.sendError(404,"tech-provider-code invalid");
+        } else {
+            resp.sendError(404, "tech-provider-code invalid");
         }
-
-
-
-
-
     }
-
 }
diff --git a/java/osx/broker/src/main/java/com/osx/broker/http/HttpClientPool.java b/java/osx/osx-broker/src/main/java/com/osx/broker/http/HttpClientPool.java
similarity index 76%
rename from java/osx/broker/src/main/java/com/osx/broker/http/HttpClientPool.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/http/HttpClientPool.java
index 3dc5f24bfe..42d62fe424 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/http/HttpClientPool.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/http/HttpClientPool.java
@@ -22,6 +22,7 @@
 import com.osx.core.constant.Dict;
 import com.osx.core.constant.PtpHttpHeader;
 import com.osx.core.utils.JsonUtil;
+import org.apache.commons.lang3.ObjectUtils;
 import org.apache.http.Header;
 import org.apache.http.HttpEntity;
 import org.apache.http.client.config.RequestConfig;
@@ -58,14 +59,27 @@ public class HttpClientPool {
     private static final Logger logger = LoggerFactory.getLogger(HttpClientPool.class);
     private static PoolingHttpClientConnectionManager poolConnManager;
-    private static RequestConfig requestConfig;
     private static CloseableHttpClient httpClient;
 
-    private static void config(HttpRequestBase httpRequestBase, Map headers) {
+    static void config(HttpRequestBase httpRequestBase, Map headers) {
+        Integer reqTimeout = null;
+        Integer connectionTimeout = null;
socketTimeout = null; + + if (MetaInfo.PROPERTY_HTTP_CLIENT_METHOD_CONFIG_MAP != null) { + Map methodConfig = MetaInfo.PROPERTY_HTTP_CLIENT_METHOD_CONFIG_MAP.get(headers.get(PtpHttpHeader.SourceMethod)); + if (methodConfig != null) { + reqTimeout = methodConfig.get(Dict.METHOD_CONFIG_REQ_TIMEOUT); + connectionTimeout = methodConfig.get(Dict.METHOD_CONFIG_CONNECTION_TIMEOUT); + socketTimeout = methodConfig.get(Dict.METHOD_CONFIG_SOCKET_TIMEOUT); + + } + } + RequestConfig requestConfig = RequestConfig.custom() - .setConnectionRequestTimeout(MetaInfo.HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT) - .setConnectTimeout(MetaInfo.HTTP_CLIENT_CONFIG_CONN_TIME_OUT) - .setSocketTimeout(MetaInfo.HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build(); + .setConnectionRequestTimeout(ObjectUtils.firstNonNull(reqTimeout, MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT)) + .setConnectTimeout(ObjectUtils.firstNonNull(connectionTimeout, MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT)) + .setSocketTimeout(ObjectUtils.firstNonNull(socketTimeout, MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT)).build(); httpRequestBase.addHeader(Dict.CONTENT_TYPE, Dict.CONTENT_TYPE_JSON_UTF8); if (headers != null) { headers.forEach((key, value) -> { @@ -85,14 +99,8 @@ public static void initPool() { Dict.HTTPS, sslsf).build(); poolConnManager = new PoolingHttpClientConnectionManager( socketFactoryRegistry); - poolConnManager.setMaxTotal(MetaInfo.HTTP_CLIENT_INIT_POOL_MAX_TOTAL); - poolConnManager.setDefaultMaxPerRoute(MetaInfo.HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE); - int socketTimeout = MetaInfo.HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT; - int connectTimeout = MetaInfo.HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT; - int connectionRequestTimeout = MetaInfo.HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT; - requestConfig = RequestConfig.custom().setConnectionRequestTimeout( - connectionRequestTimeout).setSocketTimeout(socketTimeout).setConnectTimeout( - connectTimeout).build(); + poolConnManager.setMaxTotal(MetaInfo.PROPERTY_HTTP_CLIENT_INIT_POOL_MAX_TOTAL); + poolConnManager.setDefaultMaxPerRoute(MetaInfo.PROPERTY_HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE); httpClient = createConnection(); } catch (NoSuchAlgorithmException | KeyStoreException | KeyManagementException ex) { logger.error("init http client pool failed:", ex); @@ -103,16 +111,21 @@ public static CloseableHttpClient getConnection() { } public static CloseableHttpClient createConnection() { + RequestConfig requestConfig = RequestConfig.custom() + .setConnectionRequestTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT) + .setConnectTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT) + .setSocketTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build(); CloseableHttpClient httpClient = HttpClients.custom() .setConnectionManager(poolConnManager) .setDefaultRequestConfig(requestConfig) .evictExpiredConnections() - .evictIdleConnections(5, TimeUnit.SECONDS) + .evictIdleConnections(MetaInfo.PROPERTY_HTTP_CLIENT_MAX_IDLE_TIME, TimeUnit.SECONDS) .setRetryHandler(new DefaultHttpRequestRetryHandler(0, false)) .build(); return httpClient; } public static Osx.Outbound sendPtpPost(String url, byte[] body, Map headers) { + HttpPost httpPost = new HttpPost(url); config(httpPost, headers); if(body!=null) { @@ -145,10 +158,8 @@ public static String sendGet(String url, Map headers) { private static String getResponse(HttpRequestBase request) { CloseableHttpResponse response = null; try { - response = httpClient.execute(request, - HttpClientContext.create()); + response = 
httpClient.execute(request, HttpClientContext.create()); HttpEntity entity = response.getEntity(); - String result = EntityUtils.toString(entity, Dict.CHARSET_UTF8); EntityUtils.consume(entity); return result; @@ -166,15 +177,12 @@ private static String getResponse(HttpRequestBase request) { } } - - private static Osx.Outbound getPtpHttpResponse(HttpRequestBase request) { Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); CloseableHttpResponse response = null; try { - response = httpClient.execute(request, - HttpClientContext.create()); + response = httpClient.execute(request, HttpClientContext.create()); HttpEntity entity = response.getEntity(); byte[] payload = EntityUtils.toByteArray(entity); Header[] headers = response.getAllHeaders(); @@ -187,8 +195,11 @@ private static Osx.Outbound getPtpHttpResponse(HttpRequestBase request) { } if(payload!=null) outboundBuilder.setPayload(ByteString.copyFrom(payload)); - if(headMap.get(PtpHttpHeader.ReturnCode)!=null) + if(headMap.get(PtpHttpHeader.ReturnCode)!=null){ outboundBuilder.setCode(headMap.get(PtpHttpHeader.ReturnCode)); + }else{ + logger.error("http response has no return code {}", headers); + } if(headMap.get(PtpHttpHeader.ReturnMessage)!=null) outboundBuilder.setMessage(headMap.get(PtpHttpHeader.ReturnMessage)); @@ -196,6 +207,7 @@ private static Osx.Outbound getPtpHttpResponse(HttpRequestBase request) { return outboundBuilder.build(); } catch (IOException ex) { logger.error("get http response failed:", ex); + ex.printStackTrace(); return null; } finally { try { @@ -211,9 +223,9 @@ private static Osx.Outbound getPtpHttpResponse(HttpRequestBase request) { public static String transferPost(String url, Map requestData) { HttpPost httpPost = new HttpPost(url); RequestConfig requestConfig = RequestConfig.custom() - .setConnectionRequestTimeout(MetaInfo.HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT) - .setConnectTimeout(MetaInfo.HTTP_CLIENT_TRAN_CONN_TIME_OUT) - .setSocketTimeout(MetaInfo.HTTP_CLIENT_TRAN_SOCK_TIME_OUT).build(); + .setConnectionRequestTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT) + .setConnectTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT) + .setSocketTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build(); httpPost.addHeader(Dict.CONTENT_TYPE, Dict.CONTENT_TYPE_JSON_UTF8); httpPost.setConfig(requestConfig); StringEntity stringEntity = new StringEntity(JsonUtil.object2Json(requestData), Dict.CHARSET_UTF8); diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/http/HttpsClientPool.java b/java/osx/osx-broker/src/main/java/com/osx/broker/http/HttpsClientPool.java new file mode 100644 index 0000000000..b884553526 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/http/HttpsClientPool.java @@ -0,0 +1,205 @@ +package com.osx.broker.http; + +import com.google.common.collect.Maps; +import com.google.protobuf.ByteString; +import com.osx.core.config.MetaInfo; +import com.osx.core.constant.Dict; +import com.osx.core.constant.PtpHttpHeader; +import com.osx.core.utils.OSXCertUtils; +import com.osx.core.utils.OsxX509TrustManager; +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.protocol.HttpClientContext; +import
org.apache.http.config.Registry; +import org.apache.http.config.RegistryBuilder; +import org.apache.http.conn.socket.ConnectionSocketFactory; +import org.apache.http.conn.socket.PlainConnectionSocketFactory; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.conn.ssl.TrustSelfSignedStrategy; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultHttpRequestRetryHandler; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; +import org.apache.http.ssl.SSLContextBuilder; +import org.apache.http.util.EntityUtils; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.*; +import java.io.IOException; +import java.security.*; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +public class HttpsClientPool { + private static final Logger logger = LoggerFactory.getLogger(HttpsClientPool.class); + private static final Map httpsClientPool = new HashMap<>(); + + public static CloseableHttpClient getConnection(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + String certKey = buildCertKey(caPath, clientCertPath, clientKeyPath); + CloseableHttpClient httpClient = httpsClientPool.get(certKey); + if (httpClient == null) { + httpClient = createConnection(caPath, clientCertPath, clientKeyPath); + httpsClientPool.put(certKey, httpClient); + } + return httpClient; + } + + private static String buildCertKey(String caPath, String clientCertPath, String clientKeyPath) { + return caPath + "_" + clientCertPath + "_" + clientKeyPath; + } + + public static CloseableHttpClient createConnection(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + RequestConfig requestConfig = RequestConfig.custom() + .setConnectionRequestTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT) + .setConnectTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT) + .setSocketTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build(); + CloseableHttpClient httpClient = null; + try { + SSLContextBuilder builder = new SSLContextBuilder(); + builder.loadTrustMaterial(null, new TrustSelfSignedStrategy()); + SSLConnectionSocketFactory sslsf; + if (MetaInfo.PROPERTY_HTTP_SSL_HOSTNAME_VERIFY) { + sslsf = new SSLConnectionSocketFactory(OSXCertUtils.getSSLContext(caPath, clientCertPath, clientKeyPath)); + } else { + sslsf = new SSLConnectionSocketFactory(OSXCertUtils.getSSLContext(caPath, clientCertPath, clientKeyPath), OsxX509TrustManager.HostnameVerifier2.getInstance()); + } + Registry socketFactoryRegistry = RegistryBuilder.create().register( + Dict.HTTP, PlainConnectionSocketFactory.getSocketFactory()).register( + Dict.HTTPS, sslsf).build(); + PoolingHttpClientConnectionManager poolConnManager = new PoolingHttpClientConnectionManager( + socketFactoryRegistry); + poolConnManager.setMaxTotal(MetaInfo.PROPERTY_HTTP_CLIENT_INIT_POOL_MAX_TOTAL); + poolConnManager.setDefaultMaxPerRoute(MetaInfo.PROPERTY_HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE); + httpClient = HttpClients.custom() + .setSSLSocketFactory(sslsf) + .setConnectionManager(poolConnManager) + .setDefaultRequestConfig(requestConfig) + .evictExpiredConnections() + .evictIdleConnections(MetaInfo.PROPERTY_HTTP_CLIENT_MAX_IDLE_TIME, TimeUnit.SECONDS) + .setRetryHandler(new DefaultHttpRequestRetryHandler(0, false)) + .build(); + } catch 
(NoSuchAlgorithmException | KeyStoreException | KeyManagementException ex) { + logger.error("init https client pool failed:", ex); + } + return httpClient; + } + + public static Osx.Outbound sendPtpPost(String url, byte[] body, Map headers, String caPath, String clientCertPath, String clientKeyPath) throws Exception { + + HttpPost httpPost = new HttpPost(url); + HttpClientPool.config(httpPost, headers); + if (body != null) { + ByteArrayEntity byteArrayEntity = new ByteArrayEntity(body); + httpPost.setEntity(byteArrayEntity); + } + return getPtpHttpsResponse(httpPost, caPath, clientCertPath, clientKeyPath); + } + + @SuppressWarnings("unused") + public static String sendPost(String url, byte[] body, Map headers, String caPath, String clientCertPath, String clientKeyPath) { + HttpPost httpPost = new HttpPost(url); + HttpClientPool.config(httpPost, headers); + ByteArrayEntity byteArrayEntity = new ByteArrayEntity(body); + httpPost.setEntity(byteArrayEntity); + return getResponse(httpPost, caPath, clientCertPath, clientKeyPath); + } + + public static String get(String url, Map headers, String caPath, String clientCertPath, String clientKeyPath) { + return sendGet(url, headers, caPath, clientCertPath, clientKeyPath); + } + + public static String get(String url, String caPath, String clientCertPath, String clientKeyPath) { + return sendGet(url, null, caPath, clientCertPath, clientKeyPath); + } + + public static String sendGet(String url, Map headers, String caPath, String clientCertPath, String clientKeyPath) { + HttpGet httpGet = new HttpGet(url); + HttpClientPool.config(httpGet, headers); + return getResponse(httpGet, caPath, clientCertPath, clientKeyPath); + } + + private static String getResponse(HttpRequestBase request, String caPath, String clientCertPath, String clientKeyPath) { + CloseableHttpResponse response = null; + try { + response = getConnection(caPath, clientCertPath, clientKeyPath).execute(request, HttpClientContext.create()); + HttpEntity entity = response.getEntity(); + String result = EntityUtils.toString(entity, Dict.CHARSET_UTF8); + EntityUtils.consume(entity); + return result; + } catch (Exception ex) { + logger.error("get https response failed:", ex); + return null; + } finally { + try { + if (response != null) { + response.close(); + } + } catch (IOException ex) { + logger.error("get https response failed:", ex); + } + } + } + + private static Osx.Outbound getPtpHttpsResponse(HttpRequestBase request, String caPath, String clientCertPath, String clientKeyPath) throws Exception { + Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); + CloseableHttpResponse response = null; + try { + response = getConnection(caPath, clientCertPath, clientKeyPath).execute(request, HttpClientContext.create()); + HttpEntity entity = response.getEntity(); + byte[] payload = EntityUtils.toByteArray(entity); + Header[] headers = response.getAllHeaders(); + Map headMap = Maps.newHashMap(); + if (headers != null) { + for (Header temp : headers) { + headMap.put(temp.getName(), temp.getValue()); + } + } + if (payload != null) + outboundBuilder.setPayload(ByteString.copyFrom(payload)); + if (headMap.get(PtpHttpHeader.ReturnCode) != null) + outboundBuilder.setCode(headMap.get(PtpHttpHeader.ReturnCode)); + if (headMap.get(PtpHttpHeader.ReturnMessage) != null) + outboundBuilder.setMessage(headMap.get(PtpHttpHeader.ReturnMessage)); + + EntityUtils.consume(entity); + return outboundBuilder.build(); + } catch (IOException ex) { + logger.error("get https response failed:", ex); + 
ex.printStackTrace(); + throw ex; + } finally { + try { + if (response != null) { + response.close(); + } + } catch (IOException ex) { + logger.error("get https response failed:", ex); + } + } + } + + @SuppressWarnings("unused") + private static SSLSocketFactory getSslFactory(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + KeyStore keyStore = OSXCertUtils.getKeyStore(caPath, clientCertPath, clientKeyPath); + // Initialize the ssl context object + SSLContext sslContext = SSLContext.getInstance("SSL"); + TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)}; + // Load client certificate + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom()); + // Initialize the factory + return sslContext.getSocketFactory(); + } + + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/http/PtpHttpResponse.java b/java/osx/osx-broker/src/main/java/com/osx/broker/http/PtpHttpResponse.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/http/PtpHttpResponse.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/http/PtpHttpResponse.java diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/PcpHandleInterceptor.java b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/PcpHandleInterceptor.java new file mode 100644 index 0000000000..ccc13c7fab --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/PcpHandleInterceptor.java @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.osx.broker.interceptor; + +import com.osx.api.context.Context; +import com.osx.core.service.InboundPackage; +import com.osx.core.service.Interceptor; +import com.osx.core.service.OutboundPackage; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.osx.broker.util.TransferUtil.assableContextFromInbound; + +public class PcpHandleInterceptor implements Interceptor { + Logger logger = LoggerFactory.getLogger(PcpHandleInterceptor.class); + + @Override + public void doProcess(Context context, InboundPackage inboundPackage, OutboundPackage outboundPackage) { + Osx.Inbound inbound = inboundPackage.getBody(); + assableContextFromInbound(context,inbound); + } +} diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/PushHandleInterceptor.java b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/PushHandleInterceptor.java new file mode 100644 index 0000000000..52b6dc3d54 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/PushHandleInterceptor.java @@ -0,0 +1,19 @@ +//package com.osx.broker.interceptor; +// +//import com.osx.broker.grpc.PushRequestDataWrap; +//import com.osx.core.context.Context; +//import com.osx.core.service.InboundPackage; +//import com.osx.core.service.Interceptor; +//import com.webank.ai.eggroll.api.networking.proxy.Proxy; +// +//import static com.osx.broker.util.TransferUtil.assableContextFromProxyPacket; +// +//public class PushHandleInterceptor implements Interceptor { +// +// public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { +// PushRequestDataWrap pushRequestDataWrap =inboundPackage.getBody(); +// Proxy.Packet packet = pushRequestDataWrap.getPacket(); +//// assableContextFromProxyPacket(context ,packet); +// } +// +//} diff --git a/java/osx/broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java similarity index 64% rename from java/osx/broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java index 038f508dc8..0938dd249d 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java @@ -14,39 +14,38 @@ * limitations under the License. 
*/ package com.osx.broker.interceptor; + +import com.osx.api.context.Context; +import com.osx.api.router.RouterInfo; +import com.osx.broker.ServiceContainer; import com.osx.broker.router.FateRouterService; -import com.osx.core.context.Context; -import com.osx.core.router.RouterInfo; +import com.osx.broker.router.RouterService; import com.osx.core.service.InboundPackage; import com.osx.core.service.Interceptor; import com.osx.core.service.OutboundPackage; -import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class RouterInterceptor implements Interceptor { Logger logger = LoggerFactory.getLogger(RouterInterceptor.class); - - public RouterInterceptor(FateRouterService fateRouterService){ - this.fateRouterService = fateRouterService; + public RouterInterceptor(){ } FateRouterService fateRouterService; - - @Override - public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { - + public void doProcess(Context context, InboundPackage inboundPackage, OutboundPackage outboundPackage) throws Exception { + String routerKey = buildRouterKey(context); + RouterService routerService = ServiceContainer.routerRegister.getRouterService(routerKey); String sourcePartyId = context.getSrcPartyId(); String desPartyId = context.getDesPartyId(); String sourceComponentName = context.getSrcComponent(); String desComponentName = context.getDesComponent(); - RouterInfo routerInfo = fateRouterService.route(sourcePartyId,sourceComponentName,desPartyId,desComponentName); - logger.info("============== {} {} {} {} ============",sourcePartyId,sourceComponentName,desPartyId,desComponentName); - if(logger.isDebugEnabled()) { - logger.debug("RouterInterceptor return {}", routerInfo); - } + RouterInfo routerInfo = routerService.route(sourcePartyId,sourceComponentName,desPartyId,desComponentName); +// logger.info("router===================={} =============={}",routerService,routerInfo); context.setRouterInfo(routerInfo); - + } + private String buildRouterKey (Context context){ + return context.getTechProviderCode(); } } diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/TokenValidatorInterceptor.java b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/TokenValidatorInterceptor.java new file mode 100644 index 0000000000..89bf3e19e6 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/TokenValidatorInterceptor.java @@ -0,0 +1,28 @@ +package com.osx.broker.interceptor; + +import com.osx.api.context.Context; +import com.osx.broker.ServiceContainer; +import com.osx.broker.security.TokenValidator; +import com.osx.core.config.MetaInfo; +import com.osx.core.constant.Dict; +import com.osx.core.service.InboundPackage; +import com.osx.core.service.Interceptor; +import com.osx.core.service.OutboundPackage; + +public class TokenValidatorInterceptor implements Interceptor { + + @Override + public void doProcess(Context context, InboundPackage inboundPackage, OutboundPackage outboundPackage) throws Exception { + if (MetaInfo.PROPERTY_OPEN_TOKEN_VALIDATOR) { + TokenValidator tokenValidator = ServiceContainer.tokenValidatorRegister.getTokenValidator(getValidatorKey(context), Dict.DEFAULT); + if (tokenValidator != null) { + tokenValidator.validate(context, context.getToken()); + } + } + } + + private String getValidatorKey(Context context) { + String srcPartyId = context.getSrcPartyId(); + return srcPartyId; + } +} diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/UnaryCallHandleInterceptor.java
b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/UnaryCallHandleInterceptor.java new file mode 100644 index 0000000000..2780f6b0e1 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/interceptor/UnaryCallHandleInterceptor.java @@ -0,0 +1,19 @@ +package com.osx.broker.interceptor; + + +import com.osx.api.context.Context; +import com.osx.core.service.InboundPackage; +import com.osx.core.service.Interceptor; +import com.osx.core.service.OutboundPackage; +import com.webank.ai.eggroll.api.networking.proxy.Proxy; + +import static com.osx.broker.util.TransferUtil.assableContextFromProxyPacket; + +public class UnaryCallHandleInterceptor implements Interceptor { + + @Override + public void doProcess(Context context, InboundPackage inboundPackage, OutboundPackage outboundPackage) throws Exception { + Proxy.Packet packet = inboundPackage.getBody(); + assableContextFromProxyPacket(context, packet); + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/AllocateMappedFileService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/AllocateMappedFileService.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/message/AllocateMappedFileService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/AllocateMappedFileService.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageHandler.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/AppendMessageHandler.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageHandler.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/AppendMessageHandler.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageResult.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/AppendMessageResult.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageResult.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/AppendMessageResult.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageStatus.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/AppendMessageStatus.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageStatus.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/AppendMessageStatus.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java similarity index 65% rename from java/osx/broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java index ec97d5724d..432505e501 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java @@ -15,53 +15,35 @@ */ package com.osx.broker.message; +import com.osx.core.utils.JsonUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; public class DefaultAppendMessageHandler implements AppendMessageHandler { // File at the end of the minimum fixed length empty private static final int END_FILE_MIN_BLANK_LENGTH = 4 + 4; - private final ByteBuffer 
msgIdMemory; - private final ByteBuffer msgIdV6Memory; // Store the message content private final ByteBuffer msgStoreItemMemory; // The maximum length of the message private final int maxMessageSize; - // Build Message Key - private final StringBuilder keyBuilder = new StringBuilder(); - private final StringBuilder msgIdBuilder = new StringBuilder(); Logger log = LoggerFactory.getLogger(DefaultAppendMessageHandler.class); public DefaultAppendMessageHandler(final int size) { - this.msgIdMemory = ByteBuffer.allocate(4 + 4 + 8); - this.msgIdV6Memory = ByteBuffer.allocate(16 + 4 + 8); + this.msgStoreItemMemory = ByteBuffer.allocate(size + END_FILE_MIN_BLANK_LENGTH); this.maxMessageSize = size; } - protected static int calMsgLength(int sysFlag, int srcPartyIdLength, int desPartyIdLength, int bodyLength, int topicLength, int propertiesLength) { - int bornhostLength = (sysFlag & MessageSysFlag.BORNHOST_V6_FLAG) == 0 ? 8 : 20; - int storehostAddressLength = (sysFlag & MessageSysFlag.STOREHOSTADDRESS_V6_FLAG) == 0 ? 8 : 20; final int msgLen = 4 //TOTALSIZE - + 4 //MAGICCODE - + 4 //BODYCRC - + 4 //QUEUEID + 4 //FLAG - //+ 8 //QUEUEOFFSET + 1 + (srcPartyIdLength > 0 ? srcPartyIdLength : 0) - - // + 8 //PHYSICALOFFSET + 1 + (desPartyIdLength > 0 ? desPartyIdLength : 0) + 4 //SYSFLAG + 8 //BORNTIMESTAMP - // + bornhostLength //BORNHOST - // + 8 //STORETIMESTAMP - // + storehostAddressLength //STOREHOSTADDRESS - // + 4 //RECONSUMETIMES - // + 8 //Prepared Transaction Offset + 4 + (bodyLength > 0 ? bodyLength : 0) //BODY + 2 + topicLength //TOPIC + 2 + (propertiesLength > 0 ? propertiesLength : 0) //propertiesLength @@ -81,94 +63,52 @@ public AppendMessageResult doAppend(final long fileFromOffset, final ByteBuffer String msgId = Long.toString(wroteOffset); Long queueOffset = new Long(0); final byte[] propertiesData = - msgInner.getPropertiesString() == null ? null : msgInner.getPropertiesString().getBytes(MessageDecoder.CHARSET_UTF8); - - final int propertiesLength = propertiesData == null ? 0 : propertiesData.length; + msgInner.getProperties()==null ? null : MessageDecoder.messageProperties2String (msgInner.getProperties()).getBytes(StandardCharsets.UTF_8); + final int propertiesLength = propertiesData==null? 0 : propertiesData.length; if (propertiesLength > Short.MAX_VALUE) { log.warn("putMessage message properties length too long. length={}", propertiesData.length); return new AppendMessageResult(AppendMessageStatus.PROPERTIES_SIZE_EXCEEDED); } - final byte[] topicData = msgInner.getTopic().getBytes(MessageDecoder.CHARSET_UTF8); - - final byte[] srcPartyId = - msgInner.getSrcPartyId() == null ? null : msgInner.getSrcPartyId().getBytes(MessageDecoder.CHARSET_UTF8); + final byte[] srcPartyId = msgInner.getSrcPartyId() == null ? null : msgInner.getSrcPartyId().getBytes(MessageDecoder.CHARSET_UTF8); final int srcPartyIdLength = srcPartyId != null ? srcPartyId.length : 0; - - final byte[] desPartyId = - msgInner.getDesPartyId() == null ? null : msgInner.getDesPartyId().getBytes(MessageDecoder.CHARSET_UTF8); + final byte[] desPartyId = msgInner.getDesPartyId() == null ? null : msgInner.getDesPartyId().getBytes(MessageDecoder.CHARSET_UTF8); final int desPartyIdLength = desPartyId != null ? desPartyId.length : 0; - - final int topicLength = topicData.length; - final int bodyLength = msgInner.getBody() == null ? 
0 : msgInner.getBody().length; - final int msgLen = calMsgLength(msgInner.getSysFlag(), srcPartyIdLength, desPartyIdLength, bodyLength, topicLength, propertiesLength); - // Exceeds the maximum message if (msgLen > this.maxMessageSize) { return new AppendMessageResult(AppendMessageStatus.MESSAGE_SIZE_EXCEEDED); } - // Determines whether there is sufficient free space if ((msgLen + END_FILE_MIN_BLANK_LENGTH) > maxBlank) { this.resetByteBuffer(this.msgStoreItemMemory, maxBlank); // 1 TOTALSIZE this.msgStoreItemMemory.putInt(maxBlank); -// // 2 MAGICCODE -// this.msgStoreItemMemory.putInt(1111); - // 3 The remaining space may be any value - // Here the length of the specially set maxBlank final long beginTimeMills = System.currentTimeMillis(); byteBuffer.put(this.msgStoreItemMemory.array(), 0, maxBlank); return new AppendMessageResult(AppendMessageStatus.END_OF_FILE, wroteOffset, maxBlank, msgId, msgInner.getStoreTimestamp(), queueOffset, System.currentTimeMillis() - beginTimeMills); } - // Initialization of storage space this.resetByteBuffer(msgStoreItemMemory, msgLen); - // 1 TOTALSIZE this.msgStoreItemMemory.putInt(msgLen); - // log.info("msgLen {}",msgLen); - // 2 MAGICCODE - this.msgStoreItemMemory.putInt(1000); - // 3 BODYCRC - this.msgStoreItemMemory.putInt(msgInner.getBodyCRC()); - // 4 QUEUEID - this.msgStoreItemMemory.putInt(msgInner.getQueueId()); // 5 FLAG this.msgStoreItemMemory.putInt(msgInner.getFlag()); // 6 QUEUEOFFSET - this.msgStoreItemMemory.put((byte) srcPartyIdLength); if (srcPartyId != null) this.msgStoreItemMemory.put(srcPartyId); - this.msgStoreItemMemory.put((byte) desPartyIdLength); if (desPartyId != null) this.msgStoreItemMemory.put(desPartyId); - - // this.msgStoreItemMemory.putLong(fileFromOffset + byteBuffer.position()); // 8 SYSFLAG this.msgStoreItemMemory.putInt(msgInner.getSysFlag()); // 9 BORNTIMESTAMP this.msgStoreItemMemory.putLong(msgInner.getBornTimestamp()); -// // 10 BORNHOST -// this.resetByteBuffer(bornHostHolder, bornHostLength); -// this.msgStoreItemMemory.put(msgInner.getBornHostBytes(bornHostHolder)); -// // 11 STORETIMESTAMP -// this.msgStoreItemMemory.putLong(msgInner.getStoreTimestamp()); -// // 12 STOREHOSTADDRESS -// this.resetByteBuffer(storeHostHolder, storeHostLength); -// this.msgStoreItemMemory.put(msgInner.getStoreHostBytes(storeHostHolder)); -// // 13 RECONSUMETIMES -// this.msgStoreItemMemory.putInt(msgInner.getReconsumeTimes()); -// // 14 Prepared Transaction Offset -// this.msgStoreItemMemory.putLong(msgInner.getPreparedTransactionOffset()); - // 15 BODY this.msgStoreItemMemory.putInt(bodyLength); if (bodyLength > 0) this.msgStoreItemMemory.put(msgInner.getBody()); @@ -177,13 +117,10 @@ public AppendMessageResult doAppend(final long fileFromOffset, final ByteBuffer this.msgStoreItemMemory.put(topicData); // 17 PROPERTIES this.msgStoreItemMemory.putShort((short) propertiesLength); - if (propertiesLength > 0) + if (propertiesLength > 0) { this.msgStoreItemMemory.put(propertiesData); - - final long beginTimeMills = System.currentTimeMillis(); - // Write messages to the queue buffer + } byteBuffer.put(this.msgStoreItemMemory.array(), 0, msgLen); - AppendMessageResult result = new AppendMessageResult(AppendMessageStatus.PUT_OK, wroteOffset, msgLen, msgId, msgInner.getStoreTimestamp(), queueOffset, 0); return result; @@ -193,6 +130,4 @@ private void resetByteBuffer(final ByteBuffer byteBuffer, final int limit) { byteBuffer.flip(); byteBuffer.limit(limit); } - - } \ No newline at end of file diff --git 
a/java/osx/broker/src/main/java/com/osx/broker/message/Message.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/Message.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/message/Message.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/Message.java diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageDecoder.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageDecoder.java new file mode 100644 index 0000000000..3d216e4be0 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageDecoder.java @@ -0,0 +1,410 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.osx.broker.message; + +import com.google.common.collect.Maps; +import com.osx.broker.constants.MessageFlag; +import com.osx.broker.util.MessageId; +import com.osx.broker.util.UtilAll; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.*; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class MessageDecoder { + + + static Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + + public final static Charset CHARSET_UTF8 = Charset.forName("UTF-8"); + public final static int MESSAGE_MAGIC_CODE_POSTION = 4; + public final static int MESSAGE_FLAG_POSTION = 16; + public final static int MESSAGE_PHYSIC_OFFSET_POSTION = 28; + // public final static int MESSAGE_STORE_TIMESTAMP_POSTION = 56; + public final static int MESSAGE_MAGIC_CODE = -626843481; + public static final char NAME_VALUE_SEPARATOR = 1; + public static final char PROPERTY_SEPARATOR = 2; + public static final int PHY_POS_POSITION = 4 + 4 + 4 + 4 + 4 + 8; + public static final int QUEUE_OFFSET_POSITION = 4 + 4 + 4 + 4 + 4; + public static final int SYSFLAG_POSITION = 4 + 4 + 4 + 4 + 4 + 8 + 8; + + + public static String createMessageId(final ByteBuffer input, final ByteBuffer addr, final long offset) { + input.flip(); + int msgIDLength = addr.limit() == 8 ? 
16 : 28; + input.limit(msgIDLength); + + input.put(addr); + input.putLong(offset); + + return UtilAll.bytes2string(input.array()); + } + + public static MessageExtBrokerInner buildMessageExtBrokerInner(String topic, byte[] body, + String msgCode, MessageFlag flag, String srcPartyId, String desPartyId) { + MessageExtBrokerInner messageExtBrokerInner = new MessageExtBrokerInner(); + messageExtBrokerInner.setBody(body); + messageExtBrokerInner.setTopic(topic); + messageExtBrokerInner.setFlag(flag.getFlag()); + messageExtBrokerInner.setBornTimestamp(System.currentTimeMillis()); + messageExtBrokerInner.setDesPartyId(desPartyId); + messageExtBrokerInner.setSrcPartyId(srcPartyId); + messageExtBrokerInner.setProperties(Maps.newHashMap()); + messageExtBrokerInner.setMsgId(msgCode); + return messageExtBrokerInner; + } + +// public static String createMessageId(SocketAddress socketAddress, long transactionIdhashCode) { +// InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; +// int msgIDLength = inetSocketAddress.getAddress() instanceof Inet4Address ? 16 : 28; +// ByteBuffer byteBuffer = ByteBuffer.allocate(msgIDLength); +// byteBuffer.put(inetSocketAddress.getAddress().getAddress()); +// byteBuffer.putInt(inetSocketAddress.getPort()); +// byteBuffer.putLong(transactionIdhashCode); +// byteBuffer.flip(); +// return UtilAll.bytes2string(byteBuffer.array()); +// } + +// public static MessageId decodeMessageId(final String msgId) throws UnknownHostException { +// SocketAddress address; +// long offset; +// int ipLength = msgId.length() == 32 ? 4 * 2 : 16 * 2; +// +// byte[] ip = UtilAll.string2bytes(msgId.substring(0, ipLength)); +// byte[] port = UtilAll.string2bytes(msgId.substring(ipLength, ipLength + 8)); +// ByteBuffer bb = ByteBuffer.wrap(port); +// int portInt = bb.getInt(0); +// address = new InetSocketAddress(InetAddress.getByAddress(ip), portInt); +// +// // offset +// byte[] data = UtilAll.string2bytes(msgId.substring(ipLength + 8, ipLength + 8 + 16)); +// bb = ByteBuffer.wrap(data); +// offset = bb.getLong(0); +// +// return new MessageId(address, offset); +// } + + /** + * Just decode properties from msg buffer. + * + * @param byteBuffer msg commit log buffer. + */ +// public static Map decodeProperties(ByteBuffer byteBuffer) { +// int sysFlag = byteBuffer.getInt(SYSFLAG_POSITION); +// int bornhostLength = (sysFlag & MessageSysFlag.BORNHOST_V6_FLAG) == 0 ? 8 : 20; +// int storehostAddressLength = (sysFlag & MessageSysFlag.STOREHOSTADDRESS_V6_FLAG) == 0 ? 
8 : 20; +// int bodySizePosition = 4 // 1 TOTALSIZE +// + 4 // 2 MAGICCODE +// + 4 // 3 BODYCRC +// + 4 // 4 QUEUEID +// + 4 // 5 FLAG +// + 8 // 6 QUEUEOFFSET +// + 8 // 7 PHYSICALOFFSET +// + 4 // 8 SYSFLAG +// + 8 // 9 BORNTIMESTAMP +// + bornhostLength // 10 BORNHOST +// + 8 // 11 STORETIMESTAMP +// + storehostAddressLength // 12 STOREHOSTADDRESS +// + 4 // 13 RECONSUMETIMES +// + 8; // 14 Prepared Transaction Offset +// int topicLengthPosition = bodySizePosition + 4 + byteBuffer.getInt(bodySizePosition); +// +// byte topicLength = byteBuffer.get(topicLengthPosition); +// +// short propertiesLength = byteBuffer.getShort(topicLengthPosition + 1 + topicLength); +// +// byteBuffer.position(topicLengthPosition + 1 + topicLength + 2); +// +// if (propertiesLength > 0) { +// byte[] properties = new byte[propertiesLength]; +// byteBuffer.get(properties); +// String propertiesString = new String(properties, CHARSET_UTF8); +// Map map = string2messageProperties(propertiesString); +// return map; +// } +// return null; +// } + + public static MessageExt decode(ByteBuffer byteBuffer) { + return decode(byteBuffer, true, true, false); + } + +// public static MessageExt clientDecode(ByteBuffer byteBuffer, final boolean readBody) { +// return decode(byteBuffer, readBody, true, true); +// } + + public static MessageExt decode(ByteBuffer byteBuffer, final boolean readBody) { + return decode(byteBuffer, readBody, true, false); + } + + public static MessageExt decode( + ByteBuffer byteBuffer, final boolean readBody, final boolean deCompressBody) { + return decode(byteBuffer, readBody, deCompressBody, false); + } + + public static MessageExt decode( + ByteBuffer byteBuffer, final boolean readBody, final boolean deCompressBody, final boolean isClient) { + try { + + MessageExt msgExt= new MessageExt(); + // 1 TOTALSIZE + int storeSize = byteBuffer.getInt(); + msgExt.setStoreSize(storeSize); + +// // 2 MAGICCODE +// byteBuffer.getInt(); +// +// // 3 BODYCRC +// int bodyCRC = byteBuffer.getInt(); +// msgExt.setBodyCRC(bodyCRC); +// +// // 4 QUEUEID +// int queueId = byteBuffer.getInt(); +// msgExt.setQueueId(queueId); + + // 5 FLAG + int flag = byteBuffer.getInt(); + msgExt.setFlag(flag); + + // 6 QUEUEOFFSET + int srcPartyIdLength = byteBuffer.get(); + if (srcPartyIdLength > 0) { + byte[] srcPartyBytes = new byte[srcPartyIdLength]; + byteBuffer.get(srcPartyBytes); + String srcPartyId = new String(srcPartyBytes); + msgExt.setSrcPartyId(srcPartyId); + } + +// long queueOffset = byteBuffer.getLong(); +// msgExt.setQueueOffset(queueOffset); + + // 7 PHYSICALOFFSET +// long physicOffset = byteBuffer.getLong(); +// msgExt.setCommitLogOffset(physicOffset); + + + int desPartyIdLength = byteBuffer.get(); + if (desPartyIdLength > 0) { + byte[] desPartyIdBytes = new byte[desPartyIdLength]; + byteBuffer.get(desPartyIdBytes); + String desPartyId = new String(desPartyIdBytes); + msgExt.setDesPartyId(desPartyId); + } + + + // 8 SYSFLAG + int sysFlag = byteBuffer.getInt(); + msgExt.setSysFlag(sysFlag); + + // 9 BORNTIMESTAMP + long bornTimeStamp = byteBuffer.getLong(); + msgExt.setBornTimestamp(bornTimeStamp); + + + // 15 BODY + int bodyLen = byteBuffer.getInt(); + if (bodyLen > 0) { + if (readBody) { + byte[] body = new byte[bodyLen]; + byteBuffer.get(body); + msgExt.setBody(body); + } else { + byteBuffer.position(byteBuffer.position() + bodyLen); + } + } + + // 16 TOPIC + short topicLen = byteBuffer.getShort(); + byte[] topic = new byte[(int) topicLen]; + byteBuffer.get(topic); + msgExt.setTopic(new String(topic, 
CHARSET_UTF8)); + + // 17 properties + short propertiesLength = byteBuffer.getShort(); + + if (propertiesLength > 0) { + byte[] properties = new byte[propertiesLength]; + byteBuffer.get(properties); + String propertiesString = new String(properties, CHARSET_UTF8); + Map map = string2messageProperties(propertiesString); + msgExt.setProperties(map); + + } + return msgExt; + } catch (Exception e) { + e.printStackTrace(); + byteBuffer.position(byteBuffer.limit()); + } + + return null; + } + +// public static List decodes(ByteBuffer byteBuffer) { +// return decodes(byteBuffer, true); +// } + +// public static List decodes(ByteBuffer byteBuffer, final boolean readBody) { +// List msgExts = new ArrayList(); +// while (byteBuffer.hasRemaining()) { +// MessageExt msgExt = clientDecode(byteBuffer, readBody); +// if (null != msgExt) { +// msgExts.add(msgExt); +// } else { +// break; +// } +// } +// return msgExts; +// } + + public static String messageProperties2String(Map properties) { + StringBuilder sb = new StringBuilder(); + if (properties != null) { + for (final Map.Entry entry : properties.entrySet()) { + final String name = entry.getKey(); + final String value = entry.getValue(); + + if (value == null) { + continue; + } + sb.append(name); + sb.append(NAME_VALUE_SEPARATOR); + sb.append(value); + sb.append(PROPERTY_SEPARATOR); + } + } + return sb.toString(); + } + + public static Map string2messageProperties(final String properties) { + Map map = new HashMap(); + if (properties != null) { + String[] items = properties.split(String.valueOf(PROPERTY_SEPARATOR)); + for (String i : items) { + String[] nv = i.split(String.valueOf(NAME_VALUE_SEPARATOR)); + if (2 == nv.length) { + map.put(nv[0], nv[1]); + } + } + } + + return map; + } + +// public static byte[] encodeMessage(Message message) { +// //only need flag, body, properties +// byte[] body = message.getBody(); +// int bodyLen = body.length; +// String properties = messageProperties2String(message.getProperties()); +// byte[] propertiesBytes = properties.getBytes(CHARSET_UTF8); +// //note properties length must not more than Short.MAX +// short propertiesLength = (short) propertiesBytes.length; +// int sysFlag = message.getFlag(); +// int storeSize = 4 // 1 TOTALSIZE +// + 4 // 2 MAGICCOD +// + 4 // 3 BODYCRC +// + 4 // 4 FLAG +// + 4 + bodyLen // 4 BODY +// + 2 + propertiesLength; +// ByteBuffer byteBuffer = ByteBuffer.allocate(storeSize); +// // 1 TOTALSIZE +// byteBuffer.putInt(storeSize); +// +// // 2 MAGICCODE +// byteBuffer.putInt(0); +// +// // 3 BODYCRC +// byteBuffer.putInt(0); +// +// // 4 FLAG +// int flag = message.getFlag(); +// byteBuffer.putInt(flag); +// +// // 5 BODY +// byteBuffer.putInt(bodyLen); +// byteBuffer.put(body); +// +// // 6 properties +// byteBuffer.putShort(propertiesLength); +// byteBuffer.put(propertiesBytes); +// +// return byteBuffer.array(); +// } + +// public static Message decodeMessage(ByteBuffer byteBuffer) throws Exception { +// Message message = new Message(); +// +// // 1 TOTALSIZE +// byteBuffer.getInt(); +// +// // 2 MAGICCODE +// byteBuffer.getInt(); +// +// // 3 BODYCRC +// byteBuffer.getInt(); +// +// // 4 FLAG +// int flag = byteBuffer.getInt(); +// message.setFlag(flag); +// +// // 5 BODY +// int bodyLen = byteBuffer.getInt(); +// byte[] body = new byte[bodyLen]; +// byteBuffer.get(body); +// message.setBody(body); +// +// // 6 properties +// short propertiesLen = byteBuffer.getShort(); +// byte[] propertiesBytes = new byte[propertiesLen]; +// byteBuffer.get(propertiesBytes); +// 
message.setProperties(string2messageProperties(new String(propertiesBytes, CHARSET_UTF8))); +// +// return message; +// } + +// public static byte[] encodeMessages(List messages) { +// //TO DO refactor, accumulate in one buffer, avoid copies +// List encodedMessages = new ArrayList(messages.size()); +// int allSize = 0; +// for (Message message : messages) { +// byte[] tmp = encodeMessage(message); +// encodedMessages.add(tmp); +// allSize += tmp.length; +// } +// byte[] allBytes = new byte[allSize]; +// int pos = 0; +// for (byte[] bytes : encodedMessages) { +// System.arraycopy(bytes, 0, allBytes, pos, bytes.length); +// pos += bytes.length; +// } +// return allBytes; +// } + +// public static List decodeMessages(ByteBuffer byteBuffer) throws Exception { +// //TO DO add a callback for processing, avoid creating lists +// List msgs = new ArrayList(); +// while (byteBuffer.hasRemaining()) { +// Message msg = decodeMessage(byteBuffer); +// msgs.add(msg); +// } +// return msgs; +// } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageExt.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageExt.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageExt.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageExt.java index 7bd15a4502..339cccb0a7 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageExt.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageExt.java @@ -24,16 +24,9 @@ public class MessageExt extends Message { private static final long serialVersionUID = 5720810158625748049L; - private String brokerName; - private int queueId; - private int storeSize; - - // private long queueOffset; - - private int sysFlag; private long bornTimestamp; private SocketAddress bornHost; diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java similarity index 85% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java index 5f42c56990..40fd173332 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java @@ -17,7 +17,9 @@ public class MessageExtBrokerInner extends MessageExt { private static final long serialVersionUID = 7256001576878700634L; private String propertiesString; - private long tagsCode; + + + public String getPropertiesString() { return propertiesString; @@ -27,11 +29,4 @@ public void setPropertiesString(String propertiesString) { this.propertiesString = propertiesString; } - public long getTagsCode() { - return tagsCode; - } - - public void setTagsCode(long tagsCode) { - this.tagsCode = tagsCode; - } } diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageStoreConfig.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageStoreConfig.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageStoreConfig.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageStoreConfig.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageSysFlag.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageSysFlag.java similarity index 100% rename from 
java/osx/broker/src/main/java/com/osx/broker/message/MessageSysFlag.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageSysFlag.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageWraper.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageWraper.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageWraper.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/MessageWraper.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/SelectMappedBufferResult.java b/java/osx/osx-broker/src/main/java/com/osx/broker/message/SelectMappedBufferResult.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/message/SelectMappedBufferResult.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/message/SelectMappedBufferResult.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/metric/ClusterMetricLeapArray.java b/java/osx/osx-broker/src/main/java/com/osx/broker/metric/ClusterMetricLeapArray.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/metric/ClusterMetricLeapArray.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/metric/ClusterMetricLeapArray.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java similarity index 84% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java index 1389b63760..762c6134be 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java @@ -15,15 +15,17 @@ */ package com.osx.broker.ptp; -import com.osx.core.context.Context; + +import com.osx.core.context.FateContext; import com.osx.core.exceptions.ExceptionInfo; import com.osx.core.service.AbstractServiceAdaptor; import org.ppc.ptp.Osx; -public abstract class AbstractPtpServiceAdaptor extends AbstractServiceAdaptor { +public abstract class AbstractPtpServiceAdaptor extends AbstractServiceAdaptor { @Override - protected Osx.Outbound transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { + protected Osx.Outbound transformExceptionInfo(FateContext context, ExceptionInfo exceptionInfo) { + Osx.Outbound.Builder builder = Osx.Outbound.newBuilder(); builder.setCode(exceptionInfo.getCode()); builder.setMessage(exceptionInfo.getMessage()); diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpAckService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpAckService.java similarity index 91% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpAckService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpAckService.java index e18dae12de..d491345e97 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpAckService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpAckService.java @@ -15,6 +15,7 @@ */ package com.osx.broker.ptp; +import com.osx.api.router.RouterInfo; import com.osx.broker.ServiceContainer; import com.osx.broker.consumer.UnaryConsumer; import com.osx.broker.queue.TransferQueue; @@ -23,14 +24,12 @@ import com.osx.core.constant.ActionType; import com.osx.core.constant.Dict; import com.osx.core.constant.StatusCode; 
-import com.osx.core.context.Context; +import com.osx.core.context.FateContext; import com.osx.core.exceptions.ConsumerNotExistException; import com.osx.core.exceptions.InvalidRedirectInfoException; import com.osx.core.exceptions.TransferQueueNotExistException; -import com.osx.core.router.RouterInfo; import com.osx.core.service.InboundPackage; import org.apache.commons.lang3.StringUtils; - import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,15 +40,15 @@ public class PtpAckService extends AbstractPtpServiceAdaptor { Logger logger = LoggerFactory.getLogger(PtpAckService.class); @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { + protected Osx.Outbound doService(FateContext context, InboundPackage data) { context.setActionType(ActionType.LOCAL_ACK.getAlias()); Osx.Inbound inbound = data.getBody(); Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); - String sessionId = context.getSessionId(); String topic = context.getTopic(); - Long offset = context.getRequestMsgIndex(); +// Long offset = context.getRequestMsgIndex(); + Long offset = (Long)context.getData(Dict.REQUEST_INDEX); TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(topic); - /** + /* * If the queue does not exist locally, check whether it is hosted on another node in the cluster */ if (transferQueue == null) { @@ -72,7 +71,7 @@ protected Osx.Outbound doService(Context context, InboundPackage da redirectRouterInfo.setHost(redirectIp); redirectRouterInfo.setPort(redirectPort); //context.setRouterInfo(redirectRouterInfo); - return redirect(context, inbound, redirectRouterInfo, false); + return redirect(context, inbound, redirectRouterInfo,true); } } else { throw new TransferQueueNotExistException(); @@ -80,7 +79,7 @@ protected Osx.Outbound doService(Context context, InboundPackage da } UnaryConsumer unaryConsumer = ServiceContainer.consumerManager.getUnaryConsumer(topic); if (unaryConsumer != null) { - long currentMsgIndex = unaryConsumer.ack(offset); + unaryConsumer.ack(offset); //context.setCurrentMsgIndex(currentMsgIndex); outboundBuilder.setCode(StatusCode.SUCCESS); outboundBuilder.setMessage(Dict.SUCCESS); @@ -88,8 +87,5 @@ protected Osx.Outbound doService(Context context, InboundPackage da } else { throw new ConsumerNotExistException("consumer is not exist"); } - } - - } diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java similarity index 89% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java index 687071fbfd..3907c268ff 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java @@ -18,21 +18,22 @@ import com.osx.broker.ServiceContainer; import com.osx.core.constant.Dict; import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; +import com.osx.core.context.FateContext; import com.osx.core.service.InboundPackage; import org.ppc.ptp.Osx; - import java.util.List; public class PtpCancelTransferService extends AbstractPtpServiceAdaptor { public PtpCancelTransferService() { - this.setServiceName("cansel-unary"); + this.setServiceName("cancel-unary"); } + + @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { + protected Osx.Outbound doService(FateContext context,
InboundPackage data) { String sessionId = context.getSessionId(); String topic = context.getTopic(); diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java similarity index 90% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java index 72f3704d11..4c37ac660c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java @@ -1,14 +1,14 @@ package com.osx.broker.ptp; import com.google.protobuf.ByteString; +import com.osx.api.context.Context; import com.osx.broker.ServiceContainer; import com.osx.core.config.MetaInfo; import com.osx.core.constant.ActionType; -import com.osx.core.context.Context; +import com.osx.core.context.FateContext; import com.osx.core.exceptions.RemoteRpcException; import com.osx.core.flow.*; import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.router.RouterInfo; import com.osx.core.service.InboundPackage; import com.osx.core.token.TokenRequest; import com.osx.core.token.TokenResult; @@ -27,7 +27,7 @@ public class PtpClusterTokenApplyService extends AbstractPtpServiceAdaptor { Logger logger = LoggerFactory.getLogger(PtpClusterTokenApplyService.class); @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { + protected Osx.Outbound doService(FateContext context, InboundPackage data) { context.setActionType(ActionType.CLUSTER_TOKEN_APPLY.getAlias()); Osx.Inbound inbound = data.getBody(); byte[] temp = inbound.getPayload().toByteArray(); diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java new file mode 100644 index 0000000000..9aca39b37d --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java @@ -0,0 +1,60 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.osx.broker.ptp; + +import com.osx.api.context.Context; +import com.osx.broker.ServiceContainer; +import com.osx.core.constant.ActionType; +import com.osx.core.context.FateContext; +import com.osx.core.exceptions.ParameterException; +import com.osx.core.service.InboundPackage; +import org.apache.commons.lang3.StringUtils; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class PtpClusterTopicApplyService extends AbstractPtpServiceAdaptor { + Logger logger = LoggerFactory.getLogger(PtpClusterTopicApplyService.class); + @Override + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + try { + context.setActionType(ActionType.TOPIC_APPLY.getAlias()); + Osx.Inbound inbound = data.getBody(); + String topic = inbound.getMetadataMap().get(Osx.Metadata.MessageTopic.name()); + String instanceId = inbound.getMetadataMap().get(Osx.Metadata.InstanceId.name()); + String sessionId = inbound.getMetadataMap().get(Osx.Header.SessionID.name()); + if (StringUtils.isEmpty(topic)) { + throw new ParameterException("topic is null"); + } + if (StringUtils.isEmpty(instanceId)) { + throw new ParameterException("instanceId is null"); + } + if (StringUtils.isEmpty(sessionId)) { + throw new ParameterException("sessionId is null"); + } + context.setTopic(topic); + context.setSessionId(sessionId); + Osx.Outbound outbound = ServiceContainer.transferQueueManager.applyFromMaster(topic, sessionId, instanceId); + logger.info("topic apply response for {} : {}", topic, outbound); + return outbound; + } catch (Exception e) { + logger.error("topic apply error", e); + throw e; + } + } + +} diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java new file mode 100644 index 0000000000..18d148ddf9 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java @@ -0,0 +1,135 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package com.osx.broker.ptp; + +import com.osx.api.context.Context; +import com.osx.api.router.RouterInfo; +import com.osx.broker.ServiceContainer; +import com.osx.broker.consumer.UnaryConsumer; +import com.osx.broker.queue.CreateQueueResult; +import com.osx.broker.queue.TransferQueue; +import com.osx.broker.queue.TransferQueueApplyInfo; +import com.osx.broker.util.TransferUtil; +import com.osx.core.config.MetaInfo; +import com.osx.core.constant.ActionType; +import com.osx.core.constant.Dict; +import com.osx.core.constant.StatusCode; + +import com.osx.core.context.FateContext; +import com.osx.core.exceptions.ParameterException; +import com.osx.core.exceptions.TransferQueueNotExistException; +import com.osx.core.frame.GrpcConnectionFactory; +import com.osx.core.service.InboundPackage; +import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; +import org.apache.commons.lang3.StringUtils; +import org.ppc.ptp.Osx; +import org.ppc.ptp.PrivateTransferProtocolGrpc; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.locks.ReentrantLock; + +public class PtpConsumeService extends AbstractPtpServiceAdaptor { + + Logger logger = LoggerFactory.getLogger(PtpConsumeService.class); + public PtpConsumeService() { + this.setServiceName("consume-unary"); + } + @Override + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + + context.setActionType(ActionType.DEFUALT_CONSUME.getAlias()); + Osx.Inbound inbound = data.getBody(); + String topic = context.getTopic(); + TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(topic); + if (transferQueue == null) { + if (MetaInfo.isCluster()) { + TransferQueueApplyInfo transferQueueApplyInfo = ServiceContainer.transferQueueManager.queryGlobleQueue(topic); + if (transferQueueApplyInfo == null) { + throw new TransferQueueNotExistException("topic "+topic+" not found" ); +// CreateQueueResult createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(topic, context.getSessionId(), false); +// if (createQueueResult.getTransferQueue() == null) { +// //redirect +// Osx.TopicInfo topicInfo = Osx.TopicInfo.newBuilder() +// .setTopic(topic) +// .setCreateTimestamp(System.currentTimeMillis()) +// .setIp(createQueueResult.getRedirectIp()) +// .setPort(createQueueResult.getPort()) +// .build(); +// return TransferUtil.buildResponseInner(StatusCode.TRANSFER_QUEUE_REDIRECT,"NEED REDIRECT",topicInfo.toByteArray()).build(); +// } } else { String[] args = transferQueueApplyInfo.getInstanceId().split(":"); String ip = args[0]; int port = Integer.parseInt(args[1]); RouterInfo routerInfo = new RouterInfo(); routerInfo.setHost(ip); routerInfo.setPort(port); return redirect(context, routerInfo, inbound); } } else { /* * standalone deployment: create the queue locally on demand */ logger.warn("create topic {} by consume request ", topic); CreateQueueResult createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(topic, context.getSessionId(), true); if (createQueueResult.getTransferQueue() == null) { throw new TransferQueueNotExistException(); } } } StreamObserver streamObserver = (StreamObserver) context.getData(Dict.RESPONSE_STREAM_OBSERVER); Long offset = (Long) context.getData(Dict.REQUEST_INDEX); if (offset == null) { throw new ParameterException("offset is null"); } if (offset > 0) { context.setActionType(ActionType.CUSTOMER_CONSUME.getAlias()); } +
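+        // a positive offset asks for one specific message index, while a negative
+        // offset means "consume the next message"; when no message is available yet,
+        // the negative case is served through the long-polling branch below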
UnaryConsumer consumer = ServiceContainer.consumerManager.getOrCreateUnaryConsumer(topic); + TransferQueue.TransferQueueConsumeResult transferQueueConsumeResult = consumer.consume(context, offset); + context.setReturnCode(transferQueueConsumeResult.getCode()); + if (transferQueueConsumeResult.getCode().equals(StatusCode.CONSUME_NO_MESSAGE)) { + // answered later by the scanning thread + if (offset < 0) { + + UnaryConsumer.LongPullingHold longPullingHold = new UnaryConsumer.LongPullingHold(); + longPullingHold.setNeedOffset(offset); + longPullingHold.setStreamObserver(streamObserver); + longPullingHold.setContext(context.subContext()); + String timeOutString = inbound.getMetadataMap().get(Osx.Metadata.Timeout.name()); + if (StringUtils.isNotEmpty(timeOutString)) { + long current = System.currentTimeMillis(); + longPullingHold.setExpireTimestamp(current + Long.valueOf(timeOutString)); + } + consumer.addLongPullingQueue(longPullingHold); + return null; + } + } + Osx.Outbound consumeResponse = TransferUtil.buildResponse(transferQueueConsumeResult.getCode(), "", transferQueueConsumeResult); + return consumeResponse; + + } + private Osx.Outbound redirect(Context context, RouterInfo routerInfo, Osx.Inbound inbound) { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo,true); + context.setActionType(ActionType.REDIRECT_CONSUME.getAlias()); + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); + return stub.invoke(inbound); + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java index 5a66afc18f..464b64ec96 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java @@ -15,10 +15,11 @@ */ package com.osx.broker.ptp; +import com.osx.api.context.Context; import com.osx.broker.callback.CompleteCallback; import com.osx.broker.callback.ErrorCallback; import com.osx.broker.util.TransferUtil; -import com.osx.core.context.Context; + import com.webank.ai.eggroll.api.networking.proxy.Proxy; import io.grpc.stub.StreamObserver; import org.ppc.ptp.Osx; diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java new file mode 100644 index 0000000000..ae9a370d3d --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java @@ -0,0 +1,208 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package com.osx.broker.ptp; + +import com.osx.api.router.RouterInfo; +import com.osx.broker.ServiceContainer; +import com.osx.broker.constants.MessageFlag; +import com.osx.broker.message.MessageDecoder; +import com.osx.broker.message.MessageExtBrokerInner; +import com.osx.broker.queue.CreateQueueResult; +import com.osx.broker.queue.PutMessageResult; +import com.osx.broker.queue.PutMessageStatus; +import com.osx.broker.queue.TransferQueue; +import com.osx.broker.util.TransferUtil; +import com.osx.core.config.MetaInfo; +import com.osx.core.constant.ActionType; +import com.osx.core.constant.DeployMode; +import com.osx.core.constant.Dict; +import com.osx.core.constant.StatusCode; +import com.osx.core.context.FateContext; +import com.osx.core.exceptions.*; +import com.osx.core.service.InboundPackage; +import com.osx.core.service.Interceptor; +import com.osx.core.service.OutboundPackage; +import com.osx.core.utils.FlowLogUtil; +import org.apache.commons.lang3.StringUtils; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.osx.broker.util.TransferUtil.redirect; + +public class PtpProduceService extends AbstractPtpServiceAdaptor { + + Logger logger = LoggerFactory.getLogger(PtpProduceService.class); + + + public PtpProduceService() { + this.addPostProcessor(new Interceptor() { + @Override + public void doProcess(FateContext context, InboundPackage inboundPackage, OutboundPackage outboundPackage) { + TransferQueue transferQueue = (TransferQueue) context.getData(Dict.TRANSFER_QUEUE); + if (transferQueue != null) { + transferQueue.cacheReceivedMsg(inboundPackage.getBody().getMetadataMap().get(Osx.Metadata.MessageCode.name()), outboundPackage); + } + } + }); + } + + @Override + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + TransferQueue transferQueue; + String topic = context.getTopic(); + RouterInfo routerInfo = context.getRouterInfo(); + String srcPartyId = context.getSrcPartyId(); + String sessionId = context.getSessionId(); + Osx.Inbound produceRequest = data.getBody(); + if (!MetaInfo.PROPERTY_SELF_PARTY.contains(context.getDesPartyId())) { + //forward to the remote party + Osx.Outbound response = null; + int tryTime = 0; + context.setActionType(ActionType.MSG_REDIRECT.getAlias()); + boolean usePooled = true; + while (tryTime < MetaInfo.PROPERTY_PRODUCE_MSG_MAX_TRY_TIME) { + tryTime++; + + try { + if (tryTime > 1) { + context.setRetryTime(tryTime); + produceRequest = produceRequest.toBuilder().putMetadata(Osx.Metadata.RetryCount.name(), Integer.toString(tryTime)).build(); + usePooled = false; + } + response = redirect(context, produceRequest, routerInfo,usePooled); + if (response == null) { + continue; + } + break; + } catch (RemoteRpcException e) { + logger.error("redirect retry count {}", tryTime); + if (tryTime == MetaInfo.PROPERTY_PRODUCE_MSG_MAX_TRY_TIME) { + throw e; + } else { + FlowLogUtil.printFlowLog(context); + } + try { + Thread.sleep(MetaInfo.PROPERTY_PRODUCE_MSG_RETRY_INTERVAL); + } catch (InterruptedException ignore) { + + } + } + } + return response; + } else { + /* + * handle locally + */ + if (StringUtils.isEmpty(topic)) { + throw new ParameterException(StatusCode.PARAM_ERROR, "topic is null"); + } + if (StringUtils.isEmpty(sessionId)) { + throw new ParameterException(StatusCode.PARAM_ERROR, "sessionId is null"); + } + int dataSize = produceRequest.getSerializedSize(); + context.setActionType(ActionType.MSG_DOWNLOAD.getAlias()); + context.setRouterInfo(null); + context.setDataSize(dataSize); + transferQueue =
ServiceContainer.transferQueueManager.getQueue(topic); + CreateQueueResult createQueueResult = null; + if (transferQueue == null) { + createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(topic, sessionId, false); + if (createQueueResult == null) { + throw new CreateTopicErrorException("create topic " + topic + " error"); + } + transferQueue = createQueueResult.getTransferQueue(); + } + String resource = TransferUtil.buildResource(produceRequest); + + + if (transferQueue != null) { + ServiceContainer.tokenApplyService.applyToken(context, resource, dataSize); + ServiceContainer.flowCounterManager.pass(resource, dataSize); + context.putData(Dict.TRANSFER_QUEUE, transferQueue); + String msgCode = produceRequest.getMetadataMap().get(Osx.Metadata.MessageCode.name()); + String retryCountString = produceRequest.getMetadataMap().get(Osx.Metadata.RetryCount.name()); + //handle duplicated requests here + if (StringUtils.isNotEmpty(msgCode)) { + if (transferQueue.checkMsgIdDuplicate(msgCode)) {//check whether the message already exists in the queue + if (StringUtils.isBlank(retryCountString)) {//duplicated request, not a retry + Osx.Outbound.Builder outBoundBuilder = Osx.Outbound.newBuilder(); + outBoundBuilder.setCode(StatusCode.SUCCESS); + outBoundBuilder.setMessage(Dict.DUP_MSG); + return outBoundBuilder.build(); + } else { + logger.info("receive retry request , topic {} msgcode {} try count {}", topic, msgCode, retryCountString); + } + OutboundPackage cacheReceivedMsg = transferQueue.getReceivedMsgCache(msgCode); + if (cacheReceivedMsg != null) {//return the previously cached result + return cacheReceivedMsg.getData(); + } else {//retry request, but the cached result has already expired + logger.warn("The cached message has expired , msgCode = {}", msgCode); + Osx.Outbound.Builder outBoundBuilder = Osx.Outbound.newBuilder(); + outBoundBuilder.setCode(StatusCode.SUCCESS); + outBoundBuilder.setMessage(Dict.PROCESSED_MSG); + return outBoundBuilder.build(); + } + } + } + + byte[] msgBytes = produceRequest.getPayload().toByteArray(); + String flag = produceRequest.getMetadataMap().get(Osx.Metadata.MessageFlag.name()); + MessageFlag messageFlag = MessageFlag.SENDMSG; + if (StringUtils.isNotEmpty(flag)) { + messageFlag = MessageFlag.valueOf(flag); + } + context.putData(Dict.MESSAGE_FLAG, messageFlag.name()); + MessageExtBrokerInner messageExtBrokerInner = MessageDecoder.buildMessageExtBrokerInner(topic, msgBytes, msgCode, messageFlag, context.getSrcPartyId(), + context.getDesPartyId()); + messageExtBrokerInner.getProperties().put(Dict.SESSION_ID, sessionId); + messageExtBrokerInner.getProperties().put(Dict.SOURCE_COMPONENT, context.getSrcComponent() != null ? context.getSrcComponent() : ""); + messageExtBrokerInner.getProperties().put(Dict.DES_COMPONENT, context.getDesComponent() != null ?
context.getDesComponent() : ""); + PutMessageResult putMessageResult = transferQueue.putMessage(messageExtBrokerInner); + if (putMessageResult.getPutMessageStatus() != PutMessageStatus.PUT_OK) { + throw new PutMessageException("put status " + putMessageResult.getPutMessageStatus()); + } + long logicOffset = putMessageResult.getMsgLogicOffset(); + context.putData(Dict.CURRENT_INDEX, transferQueue.getIndexQueue().getLogicOffset().get()); + Osx.Outbound.Builder outBoundBuilder = Osx.Outbound.newBuilder(); + outBoundBuilder.setCode(StatusCode.SUCCESS); + outBoundBuilder.setMessage(Dict.SUCCESS); + return outBoundBuilder.build(); + } else { + /* + * 集群内转发 + */ + if (MetaInfo.PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name())) { + RouterInfo redirectRouterInfo = new RouterInfo(); + String redirectIp = createQueueResult.getRedirectIp(); + int redirectPort = createQueueResult.getPort(); + if (StringUtils.isEmpty(redirectIp) || redirectPort == 0) { + logger.error("invalid redirect info {}:{}", redirectIp, redirectPort); + throw new InvalidRedirectInfoException(); + } + redirectRouterInfo.setHost(redirectIp); + redirectRouterInfo.setPort(redirectPort); + context.putData(Dict.ROUTER_INFO, redirectRouterInfo); + context.setActionType(ActionType.INNER_REDIRECT.getAlias()); + return redirect(context, produceRequest, redirectRouterInfo,true); + } else { + logger.error("create topic {} error", topic); + throw new ProduceMsgExcption(); + } + } + } + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpPushService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpPushService.java similarity index 78% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpPushService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpPushService.java index 4d724509b3..2881f28202 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpPushService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpPushService.java @@ -14,11 +14,13 @@ * limitations under the License. 
*/ package com.osx.broker.ptp; +import com.osx.api.context.Context; import com.osx.broker.ServiceContainer; -import com.osx.broker.grpc.PushRequestDataWrap; + import com.osx.broker.grpc.QueuePushReqStreamObserver; import com.osx.broker.util.TransferUtil; -import com.osx.core.context.Context; +import com.osx.core.config.MetaInfo; +import com.osx.core.context.FateContext; import com.osx.core.exceptions.ExceptionInfo; import com.osx.core.ptp.TargetMethod; import com.osx.core.service.AbstractServiceAdaptor; @@ -29,15 +31,17 @@ import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PtpPushService extends AbstractServiceAdaptor { +public class PtpPushService extends AbstractServiceAdaptor { Logger logger = LoggerFactory.getLogger(PtpPushService.class); @Override - protected StreamObserver doService(Context context, InboundPackage data) { + protected StreamObserver doService(FateContext context, InboundPackage data) { StreamObserver responseStreamObserver = data.getBody(); + context.setNeedPrintFlowLog(false); return new StreamObserver() { Logger logger = LoggerFactory.getLogger(PtpPushService.class); - QueuePushReqStreamObserver queuePushReqStreamObserver = new QueuePushReqStreamObserver(context,responseStreamObserver,Osx.Outbound.class); + QueuePushReqStreamObserver queuePushReqStreamObserver = new QueuePushReqStreamObserver(context,ServiceContainer.routerRegister.getRouterService(MetaInfo.PROPERTY_FATE_TECH_PROVIDER), + responseStreamObserver,Osx.Outbound.class); @Override public void onNext(Osx.Inbound inbound) { int dataSize = inbound.getSerializedSize(); @@ -61,7 +65,7 @@ public void onCompleted() { } @Override - protected StreamObserver transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { + protected StreamObserver transformExceptionInfo(FateContext context, ExceptionInfo exceptionInfo) { return null; } } diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java index 1ce8e0f25f..50207dd0d8 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java @@ -23,7 +23,7 @@ import com.osx.core.config.MetaInfo; import com.osx.core.constant.ActionType; import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; +import com.osx.core.context.FateContext; import com.osx.core.service.InboundPackage; import com.osx.core.utils.NetUtils; import org.ppc.ptp.Osx; @@ -40,7 +40,7 @@ public PtpQueryTransferQueueService() { } @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { + protected Osx.Outbound doService(FateContext context, InboundPackage data) { Osx.Inbound request = data.getBody(); Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpStreamTestService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpStreamTestService.java new file mode 100644 index 0000000000..94060dfc6c --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpStreamTestService.java @@ -0,0 +1,99 @@ +//package com.osx.broker.ptp; +// +//import 
com.google.protobuf.Parser; +//import com.osx.broker.ServiceContainer; +//import com.osx.broker.grpc.QueueStreamBuilder; +//import com.osx.broker.grpc.QueuePushReqStreamObserver; +//import com.osx.broker.util.TransferUtil; +//import com.osx.core.constant.TransferStatus; +//import com.osx.core.context.Context; +//import com.osx.core.exceptions.ExceptionInfo; +//import com.osx.core.frame.GrpcConnectionFactory; +//import com.osx.core.router.RouterInfo; +//import com.osx.core.service.AbstractServiceAdaptor; +//import com.osx.core.service.InboundPackage; +//import com.webank.ai.eggroll.api.networking.proxy.Proxy; +//import io.grpc.ManagedChannel; +//import io.grpc.stub.StreamObserver; +//import org.ppc.ptp.Osx; +//import org.ppc.ptp.PrivateTransferProtocolGrpc; +//import org.slf4j.Logger; +//import org.slf4j.LoggerFactory; +// +//public class PtpStreamTestService extends AbstractServiceAdaptor { +// +// Logger logger = LoggerFactory.getLogger(PtpStreamTestService.class); +// @Override +// protected StreamObserver doService(Context context, InboundPackage data) { +// +// return new StreamObserver() { +// TransferStatus transferStatus = TransferStatus.INIT; +// StreamObserver responseStreamObserver = data.getBody(); +// StreamObserver reqSb=null; +// boolean isDes = false; +// +//// private void initDes(Osx.Inbound first){ +//// +//// +//// reqSb = HttpStreamBuilder.buildStream(responseStreamObserver, +//// Osx.Outbound.parser(), +//// GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(),true), +//// context.getSrcPartyId(),context.getDesPartyId(),context.getSessionId()); +//// transferStatus = TransferStatus.TRANSFERING; +//// } +// +// private void initNotDes(Osx.Inbound first){ +// InboundPackage inboundPackage = new InboundPackage(); +// inboundPackage.setBody(first); +// try { +// ServiceContainer.requestHandleInterceptor.doPreProcess(context, inboundPackage); +// ServiceContainer.routerInterceptor.doPreProcess(context,inboundPackage); +// logger.info("init========={}",context.getRouterInfo()); +// }catch (Exception e){ +// e.printStackTrace(); +// } +// logger.info("ppppppppppppppppppp {}",context.getRouterInfo()); +// reqSb = QueueStreamBuilder.createStreamFromOrigin(context,responseStreamObserver, +// Osx.Outbound.parser(), +// context.getRouterInfo(), +// context.getSrcPartyId(), +// context.getDesPartyId(), +// context.getSessionId(),null); +// transferStatus = TransferStatus.TRANSFERING; +// } +// +// +// @Override +// public void onNext(Osx.Inbound inbound) { +// +//// if(isDes) { +//// if (transferStatus == TransferStatus.INIT) { +//// initDes(inbound); +//// } +//// } +//// else{ +// if(transferStatus==TransferStatus.INIT) { +// initNotDes(inbound); +// } +// // } +// +// if (reqSb != null) { +// reqSb.onNext(inbound); +// } +// } +// @Override +// public void onError(Throwable throwable) { +// reqSb.onError(throwable); +// } +// @Override +// public void onCompleted() { +// logger.info("==============onCompleted=============="); +// } +// }; +// } +// +// @Override +// protected StreamObserver transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { +// return null; +// } +//} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java similarity index 54% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java index 
60136ab9c3..4a4412d928 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java @@ -15,38 +15,26 @@ */ package com.osx.broker.ptp; +import com.osx.api.router.RouterInfo; +import com.osx.broker.util.TransferUtil; import com.osx.core.constant.ActionType; -import com.osx.core.context.Context; -import com.osx.core.exceptions.RemoteRpcException; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.router.RouterInfo; +import com.osx.core.context.FateContext; import com.osx.core.service.InboundPackage; -import io.grpc.ManagedChannel; import org.ppc.ptp.Osx; -import org.ppc.ptp.PrivateTransferProtocolGrpc; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - public class PtpUnaryCallService extends AbstractPtpServiceAdaptor { Logger logger = LoggerFactory.getLogger(PtpUnaryCallService.class); @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + context.setActionType(ActionType.UNARY_CALL_NEW.getAlias()); RouterInfo routerInfo = context.getRouterInfo(); Osx.Inbound inbound = data.getBody(); - String host = routerInfo.getHost(); - Integer port = routerInfo.getPort(); - ManagedChannel managedChannel=GrpcConnectionFactory.createManagedChannel(routerInfo,true); - PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); - Osx.Outbound outbound= null; - try { - outbound = blockingStub.invoke(inbound); - }catch(io.grpc.StatusRuntimeException e){ - logger.error("remote rpc error :router info {}",routerInfo); - throw new RemoteRpcException("remote rpc error"); - } + // logger.info("PtpUnaryCallService receive : {}",inbound); + Osx.Outbound outbound = TransferUtil.redirect(context,inbound,routerInfo,true); return outbound; } diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/Consumer.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/Consumer.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/queue/Consumer.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/Consumer.java index ce5c7c9f91..f0ad9eadb1 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/Consumer.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/Consumer.java @@ -14,7 +14,9 @@ * limitations under the License. 
*/ package com.osx.broker.queue; -import com.osx.core.context.Context; + + +import com.osx.api.context.Context; public interface Consumer { diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/CreateQueueResult.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/CreateQueueResult.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/CreateQueueResult.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/CreateQueueResult.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/MappedFile.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/MappedFile.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/MappedFile.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/MappedFile.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/MappedFileQueue.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/MappedFileQueue.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/MappedFileQueue.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/MappedFileQueue.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageLock.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/PutMessageLock.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageLock.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/PutMessageLock.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageReentrantLock.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/PutMessageReentrantLock.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageReentrantLock.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/PutMessageReentrantLock.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageResult.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/PutMessageResult.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageResult.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/PutMessageResult.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageStatus.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/PutMessageStatus.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageStatus.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/PutMessageStatus.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/ReferenceResource.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/ReferenceResource.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/ReferenceResource.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/ReferenceResource.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueue.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueue.java similarity index 74% rename from java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueue.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueue.java index 3bc3605dc3..b469911292 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueue.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueue.java @@ 
-14,29 +14,44 @@ * limitations under the License. */ package com.osx.broker.queue; -import com.osx.broker.ServiceContainer; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.osx.api.context.Context; import com.osx.broker.callback.CompleteCallback; import com.osx.broker.callback.DestoryCallback; import com.osx.broker.callback.ErrorCallback; +import com.osx.broker.callback.MsgEventCallback; import com.osx.broker.message.MessageDecoder; import com.osx.broker.message.MessageExt; import com.osx.broker.message.MessageExtBrokerInner; import com.osx.broker.message.SelectMappedBufferResult; import com.osx.broker.store.IndexQueue; import com.osx.core.config.MetaInfo; +import com.osx.core.constant.Dict; import com.osx.core.constant.StatusCode; import com.osx.core.constant.TransferStatus; -import com.osx.core.context.Context; +import com.osx.core.exceptions.PutMessageException; import com.osx.core.exceptions.TransferQueueInvalidStatusException; import com.osx.core.queue.TranferQueueInfo; +import com.osx.core.service.OutboundPackage; +import org.apache.commons.lang3.StringUtils; +import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReferenceArray; + +import static com.osx.core.config.MetaInfo.PROPERTY_TRANSFER_CACHED_MSGID_SIZE; public class TransferQueue { + + AtomicReferenceArray receivedMsgIds = new AtomicReferenceArray<>(PROPERTY_TRANSFER_CACHED_MSGID_SIZE); + private Cache> receivedMsgCache; protected final AtomicInteger wrotePosition = new AtomicInteger(0); Logger logger = LoggerFactory.getLogger(TransferQueue.class); String transferId; @@ -46,7 +61,8 @@ public class TransferQueue { volatile TransferStatus transferStatus = TransferStatus.INIT; List errorCallbacks = new ArrayList<>(); List completeCallbacks = new ArrayList<>(); - List destoryCallbacks = new ArrayList(); + List destoryCallbacks = new ArrayList<>(); + List msgCallbacks = new ArrayList<>(); long createTimestamp; long lastStatusChangeTimestamp; long lastWriteTimestamp; @@ -54,6 +70,7 @@ public class TransferQueue { boolean writeOver = false; IndexQueue indexQueue; TransferQueueManager transferQueueManager; + public TransferQueue(String transferId, TransferQueueManager transferQueueManager, String path) { this.transferId = transferId; this.transferQueueManager = transferQueueManager; @@ -61,7 +78,7 @@ public TransferQueue(String transferId, TransferQueueManager transferQueueManage this.lastStatusChangeTimestamp = this.createTimestamp; this.lastWriteTimestamp = this.createTimestamp; this.indexQueue = new IndexQueue(transferId, path, MetaInfo.PROPERTY_INDEX_MAP_FILE_SIZE); - + initReceivedMsgCache(); } public String getSessionId() { @@ -96,20 +113,42 @@ public void setIndexQueue(IndexQueue indexQueue) { this.indexQueue = indexQueue; } + public synchronized boolean checkMsgIdDuplicate(String msgId) { + for (int i = 0; i < receivedMsgIds.length(); i++) { + String tempMsgId = receivedMsgIds.get(i); + if (msgId.equals(tempMsgId)) { + return true; + } + } + return false; + } + public synchronized PutMessageResult putMessage(final MessageExtBrokerInner msg) { + if (transferStatus == TransferStatus.TRANSFERING) { + String msgId = msg.getMsgId(); this.lastWriteTimestamp = System.currentTimeMillis(); - PutMessageResult putMessageResult = ServiceContainer.messageStore.putMessage(msg); + 
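+            // the message store is now owned by this queue's TransferQueueManager rather
+            // than looked up from the global ServiceContainer, so writes go through the
+            // manager's own MessageStore instance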
PutMessageResult putMessageResult = transferQueueManager.messageStore.putMessage(msg); if (putMessageResult.isOk()) { + + int cacheIdx = wrotePosition.addAndGet(1) % MetaInfo.PROPERTY_TRANSFER_CACHED_MSGID_SIZE; + receivedMsgIds.set(cacheIdx, msgId); long beginWriteOffset = putMessageResult.getAppendMessageResult().getWroteOffset(); int size = putMessageResult.getAppendMessageResult().getWroteBytes(); - logger.info("store begin offset {},size {}", beginWriteOffset, size); putMessageResult.setMsgLogicOffset(indexQueue.putMessagePositionInfoWrapper(beginWriteOffset, size)); + + if (this.msgCallbacks.size() > 0) { + this.msgCallbacks.forEach(msgCallback -> { + msgCallback.callback(this, msg); + }); + } } else { - throw new RuntimeException(); + logger.error("topic {} put msg error", transferId); + throw new PutMessageException("topic " + msg.getTopic() + " put message error"); } return putMessageResult; } else { + logger.error("topic {} is not ready",transferId); throw new TransferQueueInvalidStatusException("invalid queue status : " + transferStatus); } } @@ -120,13 +159,15 @@ public TransferQueueConsumeResult consumeOneMessage(Context context, long reques if (transferStatus == TransferStatus.TRANSFERING) { this.lastReadTimestamp = System.currentTimeMillis(); long logicIndex = indexQueue.getLogicOffset().get(); - context.setRequestMsgIndex(requestIndex); - context.setCurrentMsgIndex(logicIndex); + + context.putData(Dict.REQUEST_INDEX, requestIndex); + //context.setCurrentMsgIndex(logicIndex); + context.putData(Dict.CURRENT_INDEX, logicIndex); if (requestIndex <= logicIndex) { SelectMappedBufferResult indexBufferResult = this.indexQueue.getIndexBuffer(requestIndex); if (indexBufferResult != null) { long pyOffset = indexBufferResult.getByteBuffer().getLong(); - SelectMappedBufferResult msgBufferResult = ServiceContainer.messageStore.consumeOneMessage(pyOffset); + SelectMappedBufferResult msgBufferResult = this.transferQueueManager.getMessageStore().consumeOneMessage(pyOffset); transferQueueConsumeResult = new TransferQueueConsumeResult(StatusCode.SUCCESS, msgBufferResult, requestIndex, logicIndex); MessageExt message = MessageDecoder.decode(transferQueueConsumeResult.getSelectMappedBufferResult().getByteBuffer()); transferQueueConsumeResult.setMessage(message); @@ -145,12 +186,12 @@ public TransferQueueConsumeResult consumeOneMessage(Context context, long reques public synchronized void destory() { logger.info("try to destory transfer queue {} ", transferId); this.indexQueue.destroy(); - logger.info("destroy index file"); + logger.info("topic {} destroy index file", transferId); destoryCallbacks.forEach(destoryCallback -> { try { destoryCallback.callback(); } catch (Exception e) { - logger.error("destory call back error", e); + logger.error("topic {} destroy call back execute error", transferId, e); } }); } @@ -165,14 +206,13 @@ public void setCreateTimestamp(long createTimestamp) { public synchronized void onCompeleted() { if (transferStatus == TransferStatus.TRANSFERING) { - transferStatus = TransferStatus.FINISH; } completeCallbacks.forEach(completeCallback -> { try { completeCallback.callback(); } catch (Exception e) { - + logger.error("complete call back error", e); } }); } @@ -191,7 +231,7 @@ public synchronized void onError(Throwable throwable) { }); } - public synchronized void registeErrorCallback(ErrorCallback errorCallback) { + public synchronized void registerErrorCallback(ErrorCallback errorCallback) { if (transferStatus == TransferStatus.TRANSFERING) {
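+            // callbacks can only be registered while the queue is still TRANSFERING;
+            // in any other state registration fails fast with an invalid-status exception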
errorCallbacks.add(errorCallback); } else { @@ -199,20 +239,24 @@ public synchronized void registeErrorCallback(ErrorCallback errorCallback) { } - public synchronized void registeDestoryCallback(DestoryCallback destoryCallback) { + public synchronized void registerDestoryCallback(DestoryCallback destoryCallback) { if (transferStatus == TransferStatus.TRANSFERING) destoryCallbacks.add(destoryCallback); else throw new TransferQueueInvalidStatusException("status is " + transferStatus); } + public synchronized void registerMsgCallback(List msgCallbacks) { + if (transferStatus == TransferStatus.TRANSFERING) { + this.msgCallbacks.addAll(msgCallbacks); + } else + throw new TransferQueueInvalidStatusException("status is " + transferStatus); + } + public TransferStatus getTransferStatus() { return transferStatus; } - // public void setTransferStatus(TransferStatus transferStatus) { -// this.transferStatus = transferStatus; -// } public AtomicInteger getWrotePosition() { return wrotePosition; } @@ -256,6 +300,27 @@ public void setLastWriteTimestamp(long lastWriteTimestamp) { this.lastWriteTimestamp = lastWriteTimestamp; } + public void cacheReceivedMsg(String msgId, OutboundPackage outboundPackage) { + + if (StringUtils.isNotEmpty(msgId)) + receivedMsgCache.put(msgId, outboundPackage); + } + + public OutboundPackage getReceivedMsgCache(String msgCode) { + + return receivedMsgCache.getIfPresent(msgCode); + } + + private void initReceivedMsgCache() { + if (receivedMsgCache == null) { + CacheBuilder cacheBuilder = CacheBuilder.newBuilder().maximumSize(MetaInfo.PRODUCE_MSG_CACHE_MAX_SIZE); + if (MetaInfo.PRODUCE_MSG_CACHE_TIMEOUT != null && MetaInfo.PRODUCE_MSG_CACHE_TIMEOUT > 0) { + cacheBuilder.expireAfterWrite(MetaInfo.PRODUCE_MSG_CACHE_TIMEOUT, TimeUnit.MILLISECONDS); + } + receivedMsgCache = cacheBuilder.build(); + } + } + public TranferQueueInfo getTransferQueueInfo() { TranferQueueInfo transferQueueInfo = new TranferQueueInfo(); transferQueueInfo.setTransferId(transferId); diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueApplyInfo.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueueApplyInfo.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueApplyInfo.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueueApplyInfo.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java similarity index 73% rename from java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java index 0096e38300..3d26bad2a9 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java @@ -14,26 +14,31 @@ * limitations under the License.
*/ package com.osx.broker.queue; + import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; -import com.google.protobuf.ByteString; +import com.osx.api.constants.Protocol; +import com.osx.api.router.RouterInfo; import com.osx.broker.ServiceContainer; +import com.osx.broker.callback.MsgEventCallback; +import com.osx.broker.consumer.EventDriverRule; +import com.osx.broker.message.AllocateMappedFileService; +import com.osx.broker.store.MessageStore; import com.osx.core.config.MasterInfo; import com.osx.core.config.MetaInfo; import com.osx.core.constant.DeployMode; import com.osx.core.constant.Dict; import com.osx.core.constant.StatusCode; import com.osx.core.constant.TransferStatus; -import com.osx.core.context.Context; import com.osx.core.exceptions.CreateTopicErrorException; import com.osx.core.exceptions.RemoteRpcException; import com.osx.core.frame.GrpcConnectionFactory; import com.osx.core.frame.ServiceThread; import com.osx.core.ptp.TargetMethod; -import com.osx.core.router.RouterInfo; import com.osx.core.utils.JsonUtil; +import com.osx.core.utils.NetUtils; import io.grpc.ManagedChannel; import io.grpc.StatusRuntimeException; import org.apache.commons.lang3.StringUtils; @@ -44,12 +49,12 @@ import org.slf4j.LoggerFactory; import java.io.File; -import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; public class TransferQueueManager { @@ -66,13 +71,35 @@ public class TransferQueueManager { volatile Set instanceIds = new HashSet<>(); ConcurrentHashMap transferQueueMap = new ConcurrentHashMap<>(); ConcurrentHashMap> sessionQueueMap = new ConcurrentHashMap<>(); - ConcurrentHashMap transferIdLockMap = new ConcurrentHashMap(); + ConcurrentHashMap transferIdLockMap = new ConcurrentHashMap<>(); + ConcurrentHashMap> msgCallBackRuleMap = new ConcurrentHashMap<>(); + + public MessageStore getMessageStore() { + return messageStore; + } + + public void setMessageStore(MessageStore messageStore) { + this.messageStore = messageStore; + } + + MessageStore messageStore; + AllocateMappedFileService allocateMappedFileService; volatile long transferApplyInfoVersion = -1; + + public MessageStore createMessageStore( + AllocateMappedFileService allocateMappedFileService) { + MessageStore messageStore = new MessageStore(allocateMappedFileService + , MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE + File.separator + MetaInfo.INSTANCE_ID + File.separator + "message-store"); + messageStore.start(); + return messageStore; + + } + private ServiceThread cleanTask = new ServiceThread() { @Override public void run() { while (true) { - this.waitForRunning(1000); + this.waitForRunning(MetaInfo.PROPERTY_TRANSFER_QUEUE_CHECK_INTERVAL); checkAndClean(); } } @@ -84,21 +111,25 @@ public String getServiceName() { }; + AllocateMappedFileService createAllocateMappedFileService() { + AllocateMappedFileService allocateMappedFileService = new AllocateMappedFileService(); + allocateMappedFileService.start(); + return allocateMappedFileService; + } + public TransferQueueManager() { + allocateMappedFileService = createAllocateMappedFileService(); + messageStore = createMessageStore(allocateMappedFileService); instanceIds.add(MetaInfo.INSTANCE_ID); if 
(MetaInfo.isCluster()) { boolean pathExists = ServiceContainer.zkClient.checkExists(ZK_QUEUE_PREFIX); if (!pathExists) { ServiceContainer.zkClient.create(ZK_QUEUE_PREFIX, false); } - List initApplyInfo = ServiceContainer.zkClient.addChildListener(ZK_QUEUE_PREFIX, (path, children) -> { - parseApplyInfo(children); - }); + List initApplyInfo = ServiceContainer.zkClient.addChildListener(ZK_QUEUE_PREFIX, (path, children) -> parseApplyInfo(children)); parseApplyInfo(initApplyInfo); ServiceContainer.zkClient.create(ZK_COMPONENTS_PREFIX + "/" + MetaInfo.INSTANCE_ID, true); - List initInstanceIds = ServiceContainer.zkClient.addChildListener(ZK_COMPONENTS_PREFIX, (path, children) -> { - handleClusterInstanceId(children); - }); + List initInstanceIds = ServiceContainer.zkClient.addChildListener(ZK_COMPONENTS_PREFIX, (path, children) -> handleClusterInstanceId(children)); ServiceContainer.zkClient.addDataListener(MASTER_PATH, (path, data, type) -> { logger.info("master event {} {}", type, data); if (data != null) { @@ -164,8 +195,6 @@ public String getServiceName() { /** * the balancing strategy has not been implemented yet * - * @param instanceId - * @return */ private String doClusterBalance(String transferId, String instanceId, @@ -180,12 +209,7 @@ private void doMasterWork() { transferQueueApplyInfoMap.forEach((k, v) -> { String instanceId = v.getInstanceId(); if (instanceIds.contains(instanceId)) { - Integer count = temp.get(instanceId); - if (count == null) { - temp.put(instanceId, 1); - } else { - temp.put(instanceId, count + 1); - } + temp.merge(instanceId, 1, Integer::sum); } }); @@ -195,52 +219,46 @@ if (transferQueueApplyInfoMap.get(k) == null) { masterQueueApplyInfoMap.remove(k); } - }; + } }); } private MasterInfo parseMasterInfo(String masterContent) { - MasterInfo masterInfo = JsonUtil.json2Object(masterContent, MasterInfo.class); - return masterInfo; + return JsonUtil.json2Object(masterContent, MasterInfo.class); } private void handleClusterInstanceId(List children) { this.instanceIds.clear(); this.instanceIds.addAll(children); - if(logger.isInfoEnabled()) { + if (logger.isInfoEnabled()) { logger.info("instance change : {}", instanceIds); } } private synchronized void parseApplyInfo(List children) { - Set childSet = Sets.newHashSet(children); - Set intersecitonSet = Sets.intersection(transferQueueApplyInfoMap.keySet(), childSet); - Set needAddSet = null; - if (intersecitonSet != null) - needAddSet = Sets.difference(childSet, intersecitonSet); - Set needRemoveSet = Sets.difference(transferQueueApplyInfoMap.keySet(), intersecitonSet); - if(logger.isInfoEnabled()) { + Set childSet = Sets.newHashSet(children); + Set intersectionSet = Sets.intersection(transferQueueApplyInfoMap.keySet(), childSet); + Set needAddSet; + needAddSet = Sets.difference(childSet, intersectionSet); + Set needRemoveSet = Sets.difference(transferQueueApplyInfoMap.keySet(), intersectionSet); + if (logger.isInfoEnabled()) { logger.info("cluster apply info add {} remove {}", needAddSet, needRemoveSet); } - if (needRemoveSet != null) { - needRemoveSet.forEach(k -> { - transferQueueApplyInfoMap.remove(k); - }); - } - if (needAddSet != null) { - needAddSet.forEach(k -> { - try { - String content = ServiceContainer.zkClient.getContent(ZK_QUEUE_PREFIX + "/" + k); - TransferQueueApplyInfo transferQueueApplyInfo = JsonUtil.json2Object(content, TransferQueueApplyInfo.class); - if (transferQueueApplyInfo != null) { - transferQueueApplyInfoMap.put(k, transferQueueApplyInfo); - } - } catch (Exception e) { - logger.error("parse apply
info from zk error",e); + needRemoveSet.forEach(k -> transferQueueApplyInfoMap.remove(k)); + needAddSet.forEach(k -> { + try { + String content = ServiceContainer.zkClient.getContent(buildZkPath(k)); + TransferQueueApplyInfo transferQueueApplyInfo = JsonUtil.json2Object(content, TransferQueueApplyInfo.class); + if (transferQueueApplyInfo != null) { + transferQueueApplyInfoMap.put(k, transferQueueApplyInfo); } - }); - } + } catch (Exception e) { + logger.error("parse apply info from zk error", e); + } + }); } + ; public List cleanByParam(String sessionId, String paramTransferId) { @@ -287,26 +305,25 @@ private void destroyInner(TransferQueue transferQueue) { private void checkAndClean() { long now = System.currentTimeMillis(); + logger.info("the total topic size is {}, total session size is {}", transferQueueMap.size(), sessionQueueMap.size()); transferQueueMap.forEach((transferId, transferQueue) -> { try { long lastReadTimestamp = transferQueue.getLastReadTimestamp(); long lastWriteTimestamp = transferQueue.getLastWriteTimestamp(); - long freeTime = now - (lastReadTimestamp > lastWriteTimestamp ? lastReadTimestamp : lastWriteTimestamp); + long freeTime = now - Math.max(lastReadTimestamp, lastWriteTimestamp); if (transferQueue.getTransferStatus() == TransferStatus.ERROR || transferQueue.getTransferStatus() == TransferStatus.FINISH) { destroy(transferId); } - if (freeTime > MetaInfo.PRPPERTY_QUEUE_MAX_FREE_TIME) { - if(logger.isInfoEnabled()) { - logger.info("transfer queue : {} freetime {} need to be destroy", transferId, freeTime); + if (freeTime > MetaInfo.PROPERTY_QUEUE_MAX_FREE_TIME) { + if (logger.isInfoEnabled()) { + logger.info("topic : {} freetime {} need to be destroy", transferId, freeTime); } destroy(transferId); - return; } } catch (Exception igrone) { - + logger.error("transferQueue clean error ", igrone); } }); - } @@ -321,6 +338,7 @@ public List getTransferQueues(List transferIds) { } return result; } + ConcurrentHashMap clusterApplyLockMap = new ConcurrentHashMap(); public synchronized TransferQueueApplyInfo handleClusterApply(String transferId, @@ -333,7 +351,7 @@ public synchronized TransferQueueApplyInfo handleClusterApply(String transferId, } else { long current = System.currentTimeMillis(); TransferQueueApplyInfo newTransferQueueApplyInfo = new TransferQueueApplyInfo(); - String intanceId = doClusterBalance(transferId, instanceId, sessionId); + doClusterBalance(transferId, instanceId, sessionId); newTransferQueueApplyInfo.setTransferId(transferId); newTransferQueueApplyInfo.setInstanceId(instanceId); newTransferQueueApplyInfo.setSessionId(sessionId); @@ -344,18 +362,25 @@ public synchronized TransferQueueApplyInfo handleClusterApply(String transferId, } - public CreateQueueResult createNewQueue(String transferId, String sessionId, boolean forceCreateLocal) { - Preconditions.checkArgument(StringUtils.isNotEmpty(transferId)); - CreateQueueResult createQueueResult = new CreateQueueResult(); + public ReentrantLock getLock(String transferId){ ReentrantLock transferCreateLock = transferIdLockMap.get(transferId); if (transferCreateLock == null) { transferIdLockMap.putIfAbsent(transferId, new ReentrantLock(false)); } transferCreateLock = transferIdLockMap.get(transferId); - transferCreateLock.lock(); - try { + return transferCreateLock; + } + + - boolean exist = this.transferQueueMap.get(transferId) != null ? 
true : false; + + public CreateQueueResult createNewQueue(String transferId, String sessionId, boolean forceCreateLocal) { + Preconditions.checkArgument(StringUtils.isNotEmpty(transferId)); + CreateQueueResult createQueueResult = new CreateQueueResult(); + ReentrantLock transferCreateLock = getLock(transferId); + try { + transferCreateLock.lock(); + boolean exist = this.transferQueueMap.get(transferId) != null; if (exist) { createQueueResult.setTransferQueue(this.transferQueueMap.get(transferId)); String[] elements = MetaInfo.INSTANCE_ID.split(":"); @@ -364,7 +389,7 @@ public CreateQueueResult createNewQueue(String transferId, String sessionId, boo return createQueueResult; } if (MetaInfo.PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name()) && !forceCreateLocal) { - /** + /* * found in the cached cluster info: return it directly */ if (this.transferQueueApplyInfoMap.get(transferId) != null) { @@ -378,21 +403,20 @@ public CreateQueueResult createNewQueue(String transferId, String sessionId, boo createQueueResult.setRedirectIp(ip); return createQueueResult; } else { - /** + /* * the queue was deleted locally but the cluster info has not been synced yet (possible delay); go through the apply flow again */ } - }; - - Osx.Outbound applyTopicResponse = this.applyFromMaster(transferId,sessionId,MetaInfo.INSTANCE_ID); + } + Osx.Outbound applyTopicResponse = this.applyFromMaster(transferId, sessionId, MetaInfo.INSTANCE_ID); logger.info("apply topic response {}", applyTopicResponse); if (applyTopicResponse != null) { - /** + /* * compare the instanceId returned by the cluster manager; if it is this instance, create the queue locally */ - String applyInstanceId = applyTopicResponse.getMetadataMap().get(Osx.Metadata.InstanceId.name()); + String applyInstanceId = applyTopicResponse.getMetadataMap().get(Osx.Metadata.InstanceId.name()); if (MetaInfo.INSTANCE_ID.equals(applyInstanceId)) { @@ -403,43 +427,42 @@ public CreateQueueResult createNewQueue(String transferId, String sessionId, boo registerTransferQueue(transferId, sessionId); //createQueueResult = applyFromCluster(transferId,sessionId); } else { - if(applyInstanceId!=null) { + if (applyInstanceId != null) { String[] args = applyInstanceId.split(":"); String ip = args[0]; String portString = args[1]; int grpcPort = Integer.parseInt(portString); createQueueResult.setRedirectIp(ip); createQueueResult.setPort(grpcPort); - }else{ + } else { throw new CreateTopicErrorException("apply topic from master error"); } - }; + } } else { throw new RuntimeException(); } } else { - /** + /* * standalone deployment: create the queue locally */ createQueueResult.setTransferQueue(localCreate(transferId, sessionId)); - String[] args = MetaInfo.INSTANCE_ID.split(":"); - String ip = args[0]; - String portString = args[1]; - createQueueResult.setPort(Integer.parseInt(portString)); - createQueueResult.setRedirectIp(ip); +// String[] args = MetaInfo.INSTANCE_ID.split("_"); +// String ip = args[0]; +// String portString = args[1]; + + createQueueResult.setPort(MetaInfo.PROPERTY_GRPC_PORT); + createQueueResult.setRedirectIp(NetUtils.getLocalHost()); } return createQueueResult; } finally { transferCreateLock.unlock(); + } } private void registerTransferQueue(String transferId, String sessionId) { - StringBuffer sb = new StringBuffer(); - sb.append(ZK_QUEUE_PREFIX).append("/"); - sb.append(transferId); - String path = sb.toString(); + String path = buildZkPath(transferId); TransferQueueApplyInfo transferQueueApplyInfo = new TransferQueueApplyInfo(); transferQueueApplyInfo.setTransferId(transferId); transferQueueApplyInfo.setSessionId(sessionId); @@ -448,25 +471,26 @@ try {
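+            // publish this queue's apply info under its zookeeper path so other
+            // cluster members can discover which instance owns the topic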
ServiceContainer.zkClient.create(path, JsonUtil.object2Json(transferQueueApplyInfo), true); } catch (KeeperException.NodeExistsException e) { - e.printStackTrace(); + logger.error("register path {} to zk error", path, e); } } + public String buildZkPath(String transferId) { + return ZK_QUEUE_PREFIX + "/" + transferId; + } + private CreateQueueResult applyFromCluster(String transferId, String sessionId) { CreateQueueResult createQueueResult = null; if (MetaInfo.PROPERTY_USE_ZOOKEEPER) { createQueueResult = new CreateQueueResult(); - StringBuffer sb = new StringBuffer(); - sb.append(ZK_QUEUE_PREFIX).append("/"); - sb.append(transferId); - String path = sb.toString(); + String path = buildZkPath(transferId); boolean exist = ServiceContainer.zkClient.checkExists(path); if (exist) { String content = ServiceContainer.zkClient.getContent(path); TransferQueueApplyInfo transferQueueApplyInfo = JsonUtil.json2Object(content, TransferQueueApplyInfo.class); } else { - /** + /* * how to balance the load evenly is still open */ TransferQueueApplyInfo transferQueueApplyInfo = new TransferQueueApplyInfo(); @@ -477,10 +501,11 @@ try { ServiceContainer.zkClient.create(path, JsonUtil.object2Json(transferQueueApplyInfo), true); } catch (KeeperException.NodeExistsException e) { - e.printStackTrace(); + logger.error("register path {} in zk error", path, e); } String content = ServiceContainer.zkClient.getContent(path); transferQueueApplyInfo = JsonUtil.json2Object(content, TransferQueueApplyInfo.class); + assert transferQueueApplyInfo != null; if (MetaInfo.INSTANCE_ID.equals(transferQueueApplyInfo.getInstanceId())) { createQueueResult.setTransferQueue(localCreate(transferId, sessionId)); } else { @@ -491,33 +516,35 @@ } } return createQueueResult; - } - public Osx.Outbound applyFromMaster( String topic,String sessionId,String instanceId) { - if (!isMaster()) { + public Osx.Outbound applyFromMaster(String topic, String sessionId, String instanceId) { - RouterInfo routerInfo = this.getMasterAddress(); + if (!isMaster()) { + RouterInfo routerInfo = this.getMasterAddress(); //context.setRouterInfo(routerInfo); - ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo,true); + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, true); PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); try { - Osx.Inbound.Builder builder = Osx.Inbound.newBuilder(); - builder.putMetadata(Osx.Metadata.MessageTopic.name(),topic); - builder.putMetadata(Osx.Metadata.InstanceId.name(),instanceId); - builder.putMetadata(Osx.Header.SessionID.name(),sessionId); + Osx.Inbound.Builder builder = Osx.Inbound.newBuilder(); + builder.putMetadata(Osx.Metadata.MessageTopic.name(), topic); + builder.putMetadata(Osx.Metadata.InstanceId.name(), instanceId); + builder.putMetadata(Osx.Header.SessionID.name(), sessionId); + builder.putMetadata(Osx.Header.TechProviderCode.name(), MetaInfo.PROPERTY_FATE_TECH_PROVIDER); + builder.putMetadata(Osx.Metadata.TargetMethod.name(), TargetMethod.APPLY_TOPIC.name()); return stub.invoke(builder.build()); - }catch(StatusRuntimeException e){ - throw new RemoteRpcException("send to "+routerInfo.toKey()+" error"); + } catch (StatusRuntimeException e) { + logger.error("apply topic {} from master error", topic, e); + throw new RemoteRpcException("send to " +
routerInfo.toKey() + " error"); } } else { TransferQueueApplyInfo transferQueueApplyInfo = this.handleClusterApply(topic, instanceId, sessionId); Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); - outboundBuilder.getMetadataMap().put(Osx.Metadata.MessageTopic.name(), topic); - outboundBuilder.getMetadataMap().put(Osx.Metadata.InstanceId.name(), instanceId); - outboundBuilder.getMetadataMap().put(Osx.Metadata.Timestamp.name(), Long.toString(transferQueueApplyInfo.getApplyTimestamp())); + outboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), topic); + outboundBuilder.putMetadata(Osx.Metadata.InstanceId.name(), instanceId); + outboundBuilder.putMetadata(Osx.Metadata.Timestamp.name(), Long.toString(transferQueueApplyInfo.getApplyTimestamp())); outboundBuilder.setCode(StatusCode.SUCCESS); outboundBuilder.setMessage(Dict.SUCCESS); return outboundBuilder.build(); @@ -530,37 +557,55 @@ private RouterInfo getMasterAddress() { String[] args = MetaInfo.masterInfo.getInstanceId().split(Dict.COLON); routerInfo.setHost(args[0]); routerInfo.setPort(Integer.parseInt(args[1])); + routerInfo.setProtocol(Protocol.grpc); return routerInfo; } private void unRegisterCluster(String transferId) { - logger.info("unRegister transferId {}", transferId); - if (MetaInfo.isCluster()) { - ServiceContainer.zkClient.delete(ZK_QUEUE_PREFIX + "/" + transferId); + + if (MetaInfo.isCluster() && MetaInfo.isCluster()) { + logger.info("unRegister topic {} from zk", transferId); + ServiceContainer.zkClient.delete(buildZkPath(transferId)); } } + private void setMsgCallBack(TransferQueue transferQueue) { + this.msgCallBackRuleMap.forEach((rule, msgCallbacks) -> { - private TransferQueue localCreate(String transferId, String sessionId) { - logger.info("create local topic {}",transferId); - TransferQueue transferQueue = new TransferQueue(transferId, this, MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE + File.separator + MetaInfo.INSTANCE_ID); + if (rule.isMatch(transferQueue)) { + // logger.info("rule {} is mactched",rule); + transferQueue.registerMsgCallback(msgCallbacks); + } else { + // logger.info("rule {} is not matched",rule); + } + }); + } + + ; + + + private TransferQueue localCreate(String topic, String sessionId) { + logger.info("create local topic {}", topic); + TransferQueue transferQueue = new TransferQueue(topic, this, MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE + File.separator + MetaInfo.INSTANCE_ID); transferQueue.setSessionId(sessionId); transferQueue.start(); - transferQueue.registeDestoryCallback(() -> { - this.transferQueueMap.remove(transferId); + transferQueue.registerDestoryCallback(() -> { + this.transferQueueMap.remove(topic); if (this.sessionQueueMap.get(sessionId) != null) { - this.sessionQueueMap.get(sessionId).remove(transferId); + this.sessionQueueMap.get(sessionId).remove(topic); } + unRegisterCluster(topic); }); - transferQueueMap.put(transferId, transferQueue); + setMsgCallBack(transferQueue); + transferQueueMap.put(topic, transferQueue); sessionQueueMap.putIfAbsent(sessionId, new HashSet<>()); - sessionQueueMap.get(sessionId).add(transferId); + sessionQueueMap.get(sessionId).add(topic); return transferQueue; } - public TransferQueue getQueue(String transferId) { - return transferQueueMap.get(transferId); + public TransferQueue getQueue(String topic) { + return transferQueueMap.get(topic); } public Map getAllLocalQueue() { @@ -568,17 +613,17 @@ public Map getAllLocalQueue() { } - private void destroy(String transferId) { - 
Preconditions.checkArgument(StringUtils.isNotEmpty(transferId));
-    ReentrantLock transferIdLock = this.transferIdLockMap.get(transferId);
+private void destroy(String topic) {
+    Preconditions.checkArgument(StringUtils.isNotEmpty(topic));
+    ReentrantLock transferIdLock = this.transferIdLockMap.get(topic);
     if (transferIdLock != null) {
         transferIdLock.lock();
     }
     try {
-        TransferQueue transferQueue = getQueue(transferId);
+        TransferQueue transferQueue = getQueue(topic);
         if (transferQueue != null) {
             destroyInner(transferQueue);
-            transferIdLockMap.remove(transferId);
+            transferIdLockMap.remove(topic);
         }
     } finally {
@@ -592,12 +637,10 @@ private void destroy(String transferId) {
 public void onError(String transferId, Throwable throwable) {
     TransferQueue transferQueue = transferQueueMap.get(transferId);
     if (transferQueue != null) {
-        /**
+        /*
         * The case handled here: when the error occurs, the consumer may not have
         * attached yet and only attaches after being triggered.
         */
-        errorCallBackExecutor.execute(() -> {
-            transferQueue.onError(throwable);
-        });
+        errorCallBackExecutor.execute(() -> transferQueue.onError(throwable));
     }
     this.destroy(transferId);
 }
@@ -617,21 +660,23 @@ public TransferQueueApplyInfo queryGlobleQueue(String transferId) {
 }

 public void destroyAll() {
-    logger.info("prepare to destory {}", transferQueueMap);
     if (MetaInfo.isCluster()) {
         try {
             if (this.isMaster()) {
                 ServiceContainer.zkClient.delete(MASTER_PATH);
             }
             ServiceContainer.zkClient.close();
-            ;
         } catch (Exception e) {
             e.printStackTrace();
         }
-        logger.info("unregister component over");
     }
     this.transferQueueMap.forEach((transferId, transferQueue) -> {
         transferQueue.destory();
     });
 }
+
+
+public void addMsgCallBackRule(EventDriverRule rule, List callbacks) {
+    this.msgCallBackRuleMap.put(rule, callbacks);
+}
}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueMonitorService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueueMonitorService.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueMonitorService.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/queue/TransferQueueMonitorService.java
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java b/java/osx/osx-broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java
new file mode 100644
index 0000000000..804d49b95f
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java
@@ -0,0 +1,445 @@
+/*
+ * Copyright 2019 The FATE Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.osx.broker.router;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.osx.api.constants.Protocol;
+import com.osx.api.context.Context;
+import com.osx.api.router.RouterInfo;
+import com.osx.broker.util.TelnetUtil;
+import com.osx.core.config.MetaInfo;
+import com.osx.core.constant.Dict;
+import com.osx.core.datasource.FileRefreshableDataSource;
+import com.osx.core.exceptions.*;
+import com.osx.core.flow.PropertyListener;
+import com.osx.core.frame.Lifecycle;
+import com.osx.core.frame.ServiceThread;
+import com.osx.core.service.InboundPackage;
+import com.osx.core.utils.FileUtils;
+import com.osx.core.utils.JsonUtil;
+import com.webank.ai.eggroll.api.networking.proxy.Proxy;
+import com.webank.eggroll.core.transfer.Transfer;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+public class DefaultFateRouterServiceImpl implements FateRouterService, Lifecycle {
+
+    private static final String IP = "ip";
+    private static final String PORT = "port";
+    private static final String URL = "url";
+    private static final String USE_SSL = "useSSL";
+    private static final String HOSTNAME = "hostname";
+    private static final String negotiationType = "negotiationType";
+    private static final String certChainFile = "certChainFile";
+    private static final String privateKeyFile = "privateKeyFile";
+    private static final String caFile = "caFile";
+    private static final String DEFAULT = "default";
+    private static final String VERSION = "version";
+
+    //Pattern urlIpPort = Pattern.compile("(\\d+\\.\\d+\\.\\d+\\.\\d+)\\:(\\d+)");
+
+    Pattern urlIpPortPattern = Pattern.compile("((http|ftp|https)://)((([a-zA-Z0-9._-]+)|([0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}))(([a-zA-Z]{2,6})|(:[0-9]{1,4})?))");
+
+    Logger logger = LoggerFactory.getLogger(DefaultFateRouterServiceImpl.class);
+    Map<String, List<RouterInfo>> routerInfoMap = new ConcurrentHashMap<>();
+    Map<String, Map<String, List<Map>>> endPointMap = new ConcurrentHashMap<>();
+    FileRefreshableDataSource fileRefreshableDataSource;
+
+    @Override
+    public RouterInfo route(Proxy.Packet packet) {
+        Preconditions.checkArgument(packet != null);
+        RouterInfo routerInfo = null;
+        Proxy.Metadata metadata = packet.getHeader();
+        Transfer.RollSiteHeader rollSiteHeader = null;
+        String dstPartyId = null;
+        try {
+            rollSiteHeader = Transfer.RollSiteHeader.parseFrom(metadata.getExt());
+            if (rollSiteHeader != null) {
+                dstPartyId = rollSiteHeader.getDstPartyId();
+            }
+        } catch (InvalidProtocolBufferException e) {
+            e.printStackTrace();
+        }
+        if (StringUtils.isEmpty(dstPartyId)) {
+            dstPartyId = metadata.getDst().getPartyId();
+        }
+        String desRole = metadata.getDst().getRole();
+        String srcRole = metadata.getSrc().getRole();
+        String srcPartyId = metadata.getSrc().getPartyId();
+        routerInfo = this.route(srcPartyId, srcRole, dstPartyId, desRole);
+        //logger.info("query router info {} to {} {} return {}", srcPartyId, dstPartyId, desRole, routerInfo);
+        return routerInfo;
+    }
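The endpoint-to-RouterInfo translation performed by buildRouterInfo below boils down to null-safe field extraction with a grpc fallback for the protocol. A hedged, self-contained sketch of that mapping (the `RouterInfoSketch` type is an illustrative stand-in, not the project's `RouterInfo`):

```java
import java.util.HashMap;
import java.util.Map;

public class RouterInfoSketch {
    String host;
    int port;
    String protocol;
    String url;

    // Pull the well-known keys out of one route_table.json endpoint entry.
    static RouterInfoSketch fromEndpoint(Map<String, Object> endpoint) {
        RouterInfoSketch info = new RouterInfoSketch();
        info.host = endpoint.get("ip") != null ? endpoint.get("ip").toString() : null;
        info.port = endpoint.get("port") != null ? ((Number) endpoint.get("port")).intValue() : 0;
        // Missing or unparseable protocol falls back to grpc, as in buildRouterInfo.
        info.protocol = endpoint.get("protocol") != null ? endpoint.get("protocol").toString() : "grpc";
        info.url = endpoint.get("url") != null ? endpoint.get("url").toString() : "";
        return info;
    }

    public static void main(String[] args) {
        Map<String, Object> endpoint = new HashMap<>();
        endpoint.put("ip", "192.168.0.1");
        endpoint.put("port", 9370);
        RouterInfoSketch info = fromEndpoint(endpoint);
        System.out.println(info.host + ":" + info.port + " via " + info.protocol);
    }
}
```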
+    private RouterInfo buildRouterInfo(Map endpoint, String srcPartyId, String srcRole, String dstPartyId, String desRole) {
+
+        Preconditions.checkArgument(endpoint != null);
+        RouterInfo routerInfo = new RouterInfo();
+        if (endpoint.get(IP) != null) {
+            routerInfo.setHost(endpoint.get(IP).toString());
+        }
+        if (endpoint.get(PORT) != null) {
+            routerInfo.setPort(((Number) endpoint.get(PORT)).intValue());
+        }
+        routerInfo.setDesPartyId(dstPartyId);
+        routerInfo.setSourcePartyId(srcPartyId);
+        routerInfo.setVersion(endpoint.get(VERSION) != null ? endpoint.get(VERSION).toString() : null);
+        routerInfo.setNegotiationType(endpoint.get(negotiationType) != null ? endpoint.get(negotiationType).toString() : "");
+        routerInfo.setDesRole(desRole);
+        Protocol protocol = Protocol.grpc;
+        if (endpoint.get(Dict.PROTOCOL) != null) {
+            try {
+                protocol = Protocol.valueOf(endpoint.get(Dict.PROTOCOL).toString());
+            } catch (Exception ignore) {
+
+            }
+        }
+        routerInfo.setProtocol(protocol);
+        routerInfo.setUrl(endpoint.get(Dict.URL) != null ? endpoint.get(Dict.URL).toString() : "");
+        routerInfo.setUseSSL(endpoint.get(Dict.USE_SSL) != null && Boolean.parseBoolean(endpoint.get(Dict.USE_SSL).toString()));
+        routerInfo.setCaFile(endpoint.get(Dict.CA_FILE) != null ? endpoint.get(Dict.CA_FILE).toString() : "");
+        routerInfo.setCertChainFile(endpoint.get(Dict.CERT_CHAIN_FILE) != null ? endpoint.get(Dict.CERT_CHAIN_FILE).toString() : "");
+        routerInfo.setPrivateKeyFile(endpoint.get(Dict.PRIVATE_KEY_FILE) != null ? endpoint.get(Dict.PRIVATE_KEY_FILE).toString() : "");
+        if (routerInfo.getProtocol().equals(Protocol.http)) {
+            if (StringUtils.isEmpty(routerInfo.getUrl())) {
+                throw new InvalidRouteInfoException();
+            }
+        }
+        if (endpoint.get(Dict.IS_CYCLE) != null && (Boolean) endpoint.get(Dict.IS_CYCLE)) {
+            logger.error("router info {} has a cyclic invocation", routerInfo.toKey());
+            throw new CycleRouteInfoException("router info has a cyclic invocation");
+        }
+        return routerInfo;
+    }
+
+    public RouterInfo route(String srcPartyId, String srcRole, String dstPartyId, String desRole) {
+        // logger.info("try to find routerInfo =={}=={}=={}=={}",srcPartyId,srcRole,dstPartyId,desRole);
+        RouterInfo routerInfo = null;
+        Map<String, List<Map>> partyIdMap = this.endPointMap.containsKey(dstPartyId) ? this.endPointMap.get(dstPartyId) : this.endPointMap.get(DEFAULT);
+        if (partyIdMap != null) {
+            if (StringUtils.isNotEmpty(desRole) && partyIdMap.get(desRole) != null) {
+                List<Map> ips = partyIdMap.getOrDefault(desRole, null);
+                if (ips != null && ips.size() > 0) {
+                    Map endpoint = ips.get((int) (System.currentTimeMillis() % ips.size()));
+                    routerInfo = buildRouterInfo(endpoint, srcPartyId, srcRole, dstPartyId, desRole);
+                }
+            } else {
+
+                List<Map> ips = partyIdMap.getOrDefault(DEFAULT, null);
+                if (ips != null && ips.size() > 0) {
+                    Map endpoint = ips.get((int) (System.currentTimeMillis() % ips.size()));
+                    routerInfo = buildRouterInfo(endpoint, srcPartyId, srcRole, dstPartyId, desRole);
+                }
+                if (StringUtils.isNotEmpty(desRole)) {
+                    // logger.warn("role {} is not found,return default router info ",desRole);
+                }
+            }
+        }
+
+        return routerInfo;
+    }
+
+
+    Map<String, Map<String, List<Map>>> initRouteTable(Map confJson) {
+        // BasicMeta.Endpoint.Builder endpointBuilder = BasicMeta.Endpoint.newBuilder();
+        Map<String, Map<String, List<Map>>> newRouteTable = new ConcurrentHashMap<>();
+        // loop through coordinator
+
+        confJson.forEach((k, v) -> {
+            String coordinatorKey = k.toString();
+            Map coordinatorValue = (Map) v;
+
+            Map<String, List<Map>> serviceTable = newRouteTable.get(coordinatorKey);
+            if (serviceTable == null) {
+                serviceTable = new ConcurrentHashMap<>(4);
+                newRouteTable.put(coordinatorKey, serviceTable);
+            }
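The selection policy in route(...) above is a party-id lookup with a "default" fallback, followed by a clock-based pick across candidate endpoints. A minimal sketch of that policy with stand-in generic types (not the project's table classes):

```java
import java.util.List;
import java.util.Map;

public class RouteSelectionSketch {

    // Party-id lookup falls back to the "default" table entry, then one
    // candidate is chosen by the current time modulo the candidate count.
    static <T> T pick(Map<String, List<T>> table, String partyId) {
        List<T> candidates = table.containsKey(partyId) ? table.get(partyId) : table.get("default");
        if (candidates == null || candidates.isEmpty()) {
            return null; // the caller treats a null route as "no router info"
        }
        // Not true round-robin: calls within the same millisecond pick the same entry.
        return candidates.get((int) (System.currentTimeMillis() % candidates.size()));
    }
}
```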
+            // loop through role in coordinator
+            for (Object roleEntryObject : coordinatorValue.entrySet()) {
+                Map.Entry roleEntry = (Map.Entry) roleEntryObject;
+                String roleKey = roleEntry.getKey().toString();
+                if (roleKey.equals("createTime") || roleKey.equals("updateTime")) {
+                    continue;
+                }
+                List roleValue = (List) roleEntry.getValue();
+
+                List<Map> endpoints = serviceTable.get(roleKey);
+                if (endpoints == null) {
+                    endpoints = new ArrayList<>();
+                    serviceTable.put(roleKey, endpoints);
+                }
+                // loop through endpoints
+                for (Object endpointElement : roleValue) {
+                    Map element = Maps.newHashMap();
+                    Map endpointJson = (Map) endpointElement;
+                    element.putAll(endpointJson);
+                    endpoints.add(element);
+                }
+            }
+
+        });
+
+        return newRouteTable;
+    }
+
+    @Override
+    public void init() {
+
+    }
+
+    public void start() {
+        String currentPath = getRouterTablePath();
+        logger.info("load router file {}", currentPath);
+        File confFile = new File(currentPath);
+        FileRefreshableDataSource fileRefreshableDataSource = null;
+        try {
+            fileRefreshableDataSource = new FileRefreshableDataSource(confFile, (source) -> {
+                // logger.info("read route_table {}", source);
+                return source;
+            });
+            fileRefreshableDataSource.getProperty().addListener(new RouterTableListener());
+
+        } catch (FileNotFoundException e) {
+            logger.error("router file {} is not found", currentPath);
+        }
+        /*
+         * Check the route table for routing cycles, and verify that each endpoint is reachable.
+         */
+        ServiceThread routerInfoChecker = new ServiceThread() {
+
+            @Override
+            public void run() {
+                while (true) {
+                    endPointMap.forEach((desPartyId, desPoint) -> {
+                        desPoint.forEach((role, routerElementMap) -> {
+                            routerElementMap.forEach(endPoint -> {
+
+                                String ip = null;
+                                int port = 0;
+                                Protocol protocol = Protocol.grpc;
+                                try {
+                                    if (endPoint.get(Dict.PROTOCOL) != null) {
+                                        try {
+                                            protocol = Protocol.valueOf(endPoint.get(Dict.PROTOCOL).toString());
+                                        } catch (Exception e) {
+                                            logger.warn("route info {}->{} protocol is invalid, please check route_table.json", desPartyId, role);
+                                        }
+                                    }
+                                    if (endPoint.get(Dict.URL) != null) {
+                                        String ipPortString = getIpInfoFromUrl(endPoint.get(Dict.URL).toString());
+                                        if (StringUtils.isNotEmpty(ipPortString)) {
+                                            ip = ipPortString.split(Dict.COLON)[0];
+                                            String portString = ipPortString.split(Dict.COLON)[1];
+                                            port = Integer.parseInt(portString);
+                                        }
+                                    }
+                                    if (protocol.equals(Protocol.grpc)) {
+                                        if (endPoint.get(IP) != null) {
+                                            ip = endPoint.get(IP).toString();
+                                        }
+                                        if (endPoint.get(PORT) != null) {
+                                            port = ((Number) endPoint.get(PORT)).intValue();
+                                        }
+                                    }
+                                    //if (!MetaInfo.PROPERTY_SELF_PARTY.contains(desPartyId)) {
+
+                                    boolean isCycle = checkCycle(ip, port);
+                                    if (isCycle) {
+                                        logger.warn("route info {}->{}->{}->{} is a cycle, please check route_table.json", desPartyId, role, ip, port);
+                                    }
+                                    endPoint.put(Dict.IS_CYCLE, isCycle);
+                                    //}
+                                    checkConnected(desPartyId, role, ip, port);
+
+                                } catch (Exception ignore) {
+                                    ignore.printStackTrace();
+                                }
+                            }
+                            );
+                        });
+                    }
+                    );
+
+                    this.waitForRunning(60000);
+                }
+            }
+
+            @Override
+            public String getServiceName() {
+                return "cycle_checker";
+            }
+        };
+        routerInfoChecker.start();
+    }
+
+    private String getRouterTablePath() {
+        return MetaInfo.PROPERTY_CONFIG_DIR + "/broker/route_table.json";
+    }
+
+    @Override
+    public void destroy() {
+
+    }
+
+    private void checkConnected(String partyId, String role, String ip, int port) {
+
+        if (MetaInfo.PROPERTY_USE_REMOTE_HEALTH_CHECK) {
+            if (StringUtils.isNotEmpty(ip)) {
+
+                boolean result = TelnetUtil.tryTelnet(ip, port);
+                if (!result) {
+                    // logger.warn("route info {}->{}->{}->{} unable to connect, please check route_table.json",
partyId, role, ip, port);
+
+                }
+            }
+        }
+    }
+
+    private boolean checkCycle(String ip, int port) {
+
+        boolean cycle = false;
+
+        if (MetaInfo.PROPERTY_OPEN_ROUTE_CYCLE_CHECKER) {
+            String localIp = MetaInfo.INSTANCE_ID.split(":")[0];
+
+            if (localIp.equals(ip) || Dict.LOCALHOST.equals(ip) || Dict.LOCALHOST2.equals(ip)) {
+                if (MetaInfo.PROPERTY_GRPC_PORT == (port)) {
+                    cycle = true;
+                }
+                if (MetaInfo.PROPERTY_OPEN_GRPC_TLS_SERVER) {
+                    if (MetaInfo.PROPERTY_GRPC_TLS_PORT == port) {
+                        cycle = true;
+                    }
+                }
+                if (MetaInfo.PROPERTY_OPEN_HTTP_SERVER) {
+                    if (MetaInfo.PROPERTY_HTTP_PORT == (port)) {
+                        cycle = true;
+                    }
+                }
+            }
+        }
+
+        return cycle;
+    }
+
+
+    private class RouterTableListener implements PropertyListener {
+
+        @Override
+        public void configUpdate(String value) {
+            logger.info("found route_table.json has been changed, updated content {}", value);
+            Map confJson = JsonUtil.json2Object(value, Map.class);
+            // JsonObject confJson = JsonParser.parseString(value).getAsJsonObject();
+            Map content = (Map) confJson.get("route_table");
+            endPointMap = initRouteTable(content);
+        }
+
+        @Override
+        public void configLoad(String value) {
+            Map confJson = JsonUtil.json2Object(value, Map.class);
+            if (confJson != null) {
+
+                // throw new ConfigErrorException("content of route_table.json is invalid");
+
+                Map content = (Map) confJson.get("route_table");
+                endPointMap = initRouteTable(content);
+                logger.info("load router config {}", JsonUtil.formatJson(JsonUtil.object2Json(endPointMap)));
+
+            } else {
+                logger.error("content of route_table.json is invalid, content is {}", value);
+
+            }
+        }
+    }
+
+
+    public String getIpInfoFromUrl(String url) {
+        Matcher m = urlIpPortPattern.matcher(url);
+        String result = "";
+        if (m.find()) {
+            result = m.group(3);
+        }
+        return result;
+    }
+
+    public boolean saveRouterTable(Context context, InboundPackage data) {
+        try {
+            String inboundRouteJson = (String) context.getData("route");
+            if (StringUtils.isNotBlank(inboundRouteJson)) {
+                Map routeMap = JsonUtil.object2Objcet(inboundRouteJson, new TypeReference<Map<String, Map<String, List<RouterInfo>>>>() {
+                });
+                Map route_table = (Map) routeMap.get("route_table");
+                route_table.forEach((partyId, value) -> {
+                    List<RouterInfo> routeList = (List<RouterInfo>) value;
+                    for (RouterInfo routerInfo : routeList) {
+                        routerInfo.setProtocol(StringUtils.isBlank(routerInfo.getProtocol().toString()) ?
Protocol.grpc : routerInfo.getProtocol()); + } + }); + inboundRouteJson = JsonUtil.object2Json(routeMap); + } + String routerTablePath = getRouterTablePath(); + File routerTableFile = new File(routerTablePath); + if (!routerTableFile.exists()) { + if (!routerTableFile.getParentFile().exists()) { + if (!routerTableFile.getParentFile().mkdirs()) { + logger.warn("mkdir failed : {}", routerTableFile.getParent()); + return false; + } + } + if (!routerTableFile.createNewFile()) { + logger.warn("create router_table.json failed : {}", routerTableFile.getAbsoluteFile()); + return false; + } + } + return FileUtils.writeStr2ReplaceFileSync(JsonUtil.formatJson(inboundRouteJson), routerTablePath); + } catch (Exception e) { + logger.error("save router table failed ", e); + ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); + context.setReturnCode(exceptionInfo.getCode()); + context.setReturnMsg("save router table failed"); + return false; + } + } + + public static void main(String[] args) { +// System.out.println(MetaInfo.PROPERTY_USER_DIR); +// System.out.println(MetaInfo.PROPERTY_USER_HOME); +// System.out.println(Thread.currentThread().getContextClassLoader().getResource("").getPath()); +// System.out.println(Thread.currentThread().getContextClassLoader().getResource("route_table.json")); +// System.out.println(Thread.currentThread().getContextClassLoader().getResource("flowRule.json")); + DefaultFateRouterServiceImpl defaultFateRouterService = new DefaultFateRouterServiceImpl(); + defaultFateRouterService.getIpInfoFromUrl("http://127.0.0.1:9000/xxxx"); + + + } + + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/router/FateRouterService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/router/FateRouterService.java similarity index 81% rename from java/osx/broker/src/main/java/com/osx/broker/router/FateRouterService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/router/FateRouterService.java index 630c954f44..7fbad503c5 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/router/FateRouterService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/router/FateRouterService.java @@ -16,13 +16,10 @@ package com.osx.broker.router; -import com.osx.core.router.RouterInfo; +import com.osx.api.router.RouterInfo; import com.webank.ai.eggroll.api.networking.proxy.Proxy; -public interface FateRouterService { - - RouterInfo route(String srcPartyId, String srcRole, String dstPartyId, String desRole); +public interface FateRouterService extends RouterService{ RouterInfo route(Proxy.Packet packet); - } diff --git a/java/osx/broker/src/main/java/com/osx/broker/router/RemoteRouterDataSource.java b/java/osx/osx-broker/src/main/java/com/osx/broker/router/RemoteRouterDataSource.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/router/RemoteRouterDataSource.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/router/RemoteRouterDataSource.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/router/RouterMetric.java b/java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterMetric.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/router/RouterMetric.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterMetric.java diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterRegister.java b/java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterRegister.java new file mode 100644 index 
0000000000..aff3904efd
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterRegister.java
@@ -0,0 +1,70 @@
+package com.osx.broker.router;
+
+import com.osx.core.config.MetaInfo;
+import com.osx.core.constant.Dict;
+import com.osx.core.frame.Lifecycle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.InputStream;
+import java.util.Properties;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+public class RouterRegister implements Lifecycle {
+
+    Logger logger = LoggerFactory.getLogger(RouterRegister.class);
+
+    private final String ROUTER_CONFIG_FILE = "components/router.properties";
+
+    private ConcurrentMap<String, RouterService> routerServiceMap = new ConcurrentHashMap<>();
+
+    public RouterService getRouterService(String key) {
+        return routerServiceMap.get(key);
+    }
+
+    @Override
+    public void init() {
+        String configDir = MetaInfo.PROPERTY_CONFIG_DIR;
+        String fileName = configDir + Dict.SLASH + ROUTER_CONFIG_FILE;
+        File file = new File(fileName);
+        Properties config = new Properties();
+        try (InputStream inputStream = new BufferedInputStream(new FileInputStream(file))) {
+            config.load(inputStream);
+        } catch (Exception e) {
+            logger.error("can not find {}", fileName);
+        }
+        config.forEach((k, v) -> {
+            if (v != null) {
+                try {
+                    Class genClass = Class.forName(v.toString());
+                    Object rawObject = genClass.getConstructor().newInstance();
+                    routerServiceMap.put(k.toString(), (RouterService) rawObject);
+                    if (rawObject instanceof Lifecycle) {
+                        ((Lifecycle) rawObject).init();
+                    }
+                } catch (Exception e) {
+                    logger.error("register router error {} : {}", k, v, e);
+                }
+            }
+        });
+    }
+
+    @Override
+    public void start() {
+        routerServiceMap.forEach((k, v) -> {
+            if (v instanceof Lifecycle) {
+                ((Lifecycle) v).start();
+            }
+        });
+        logger.info("router register : {}", routerServiceMap);
+    }
+
+    @Override
+    public void destroy() {
+
+    }
+}
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterService.java
new file mode 100644
index 0000000000..e420d1e002
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/router/RouterService.java
@@ -0,0 +1,8 @@
+package com.osx.broker.router;
+
+
+import com.osx.api.router.RouterInfo;
+
+public interface RouterService {
+    RouterInfo route(String srcPartyId, String srcRole, String dstPartyId, String desRole);
+}
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/security/MockTokenGenerator.java b/java/osx/osx-broker/src/main/java/com/osx/broker/security/MockTokenGenerator.java
new file mode 100644
index 0000000000..da51d7c108
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/security/MockTokenGenerator.java
@@ -0,0 +1,13 @@
+package com.osx.broker.security;
+
+
+import com.osx.api.context.Context;
+
+public class MockTokenGenerator implements TokenGenerator {
+
+
+    @Override
+    public String createNewToken(Context context) {
+        return "mock";
+    }
+}
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenGenerator.java b/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenGenerator.java
new file mode 100644
index 0000000000..75ccb81f4c
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenGenerator.java
@@ -0,0 +1,10 @@
+package com.osx.broker.security;
+
+
+import com.osx.api.context.Context;
+
+public interface 
TokenGenerator {
+
+    String createNewToken(Context context);
+
+}
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenGeneratorRegister.java b/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenGeneratorRegister.java
new file mode 100644
index 0000000000..5b54632e36
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenGeneratorRegister.java
@@ -0,0 +1,73 @@
+package com.osx.broker.security;
+
+import com.osx.core.config.MetaInfo;
+import com.osx.core.frame.Lifecycle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+import java.io.*;
+import java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.ConcurrentHashMap;
+
+
+public class TokenGeneratorRegister implements Lifecycle {
+
+    Logger logger = LoggerFactory.getLogger(TokenGeneratorRegister.class);
+
+    final String DEFAULT_KEY = "default";
+
+    private Map<String, TokenGenerator> tokenGeneratorMap = new ConcurrentHashMap<>();
+
+
+    @Override
+    public void init() {
+        if (MetaInfo.PROPERTY_OPEN_TOKEN_GENERATOR) {
+            String configFilePath = MetaInfo.PROPERTY_TOKEN_GENERATOR_CONFIG_PATH;
+            File file = new File(configFilePath);
+            Properties config = new Properties();
+            try (InputStream inputStream = new BufferedInputStream(new FileInputStream(file))) {
+                config.load(inputStream);
+            } catch (Exception e) {
+                logger.error("can not load token generator config {}", configFilePath, e);
+            }
+            config.forEach((k, v) -> {
+                if (v != null) {
+                    try {
+                        Class genClass = Class.forName(v.toString());
+                        Object rawObject = genClass.getConstructor().newInstance();
+                        if (!(rawObject instanceof TokenGenerator)) {
+                            logger.error("create token generator error, {}", v);
+                            return;
+                        }
+                        tokenGeneratorMap.put(k.toString(), (TokenGenerator) rawObject);
+                    } catch (ReflectiveOperationException e) {
+                        throw new RuntimeException(e);
+                    }
+                }
+            });
+        }
+
+
+    }
+
+    @Override
+    public void start() {
+        logger.info("register token generator : {}", this.tokenGeneratorMap);
+    }
+
+    @Override
+    public void destroy() {
+
+    }
+}
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenValidator.java b/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenValidator.java
new file mode 100644
index 0000000000..9b527210e4
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenValidator.java
@@ -0,0 +1,8 @@
+package com.osx.broker.security;
+
+
+import com.osx.api.context.Context;
+
+public interface TokenValidator {
+    public void validate(Context context, String token);
+}
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenValidatorRegister.java b/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenValidatorRegister.java
new file mode 100644
index 0000000000..0be7c6d0b7
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/security/TokenValidatorRegister.java
@@ -0,0 +1,76 @@
+package com.osx.broker.security;
+
+import com.osx.core.config.MetaInfo;
+import com.osx.core.constant.Dict;
+import com.osx.core.frame.Lifecycle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.InputStream;
+import java.lang.reflect.InvocationTargetException;
+import 
java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class TokenValidatorRegister implements Lifecycle {
+
+    Logger logger = LoggerFactory.getLogger(TokenValidatorRegister.class);
+
+    final String DEFAULT_KEY = "default";
+    final String TOKEN_VALIDATOR_CONFIG_FILE = "token_validator.properties";
+
+    private Map<String, TokenValidator> tokenValidatorMap = new ConcurrentHashMap<>();
+
+    public TokenValidator getTokenValidator(String key, String defaultKey) {
+        TokenValidator result = tokenValidatorMap.get(key);
+        if (result == null) {
+            result = tokenValidatorMap.get(defaultKey);
+        }
+        return result;
+    }
+    @Override
+    public void init() {
+        if (MetaInfo.PROPERTY_OPEN_TOKEN_GENERATOR) {
+            String configDir = MetaInfo.PROPERTY_CONFIG_DIR;
+            String fileName = configDir + Dict.SLASH + TOKEN_VALIDATOR_CONFIG_FILE;
+            File file = new File(fileName);
+            Properties config = new Properties();
+            try (InputStream inputStream = new BufferedInputStream(new FileInputStream(file))) {
+                config.load(inputStream);
+            } catch (Exception e) {
+                logger.error("parse file {} error", fileName);
+            }
+
+            config.forEach((k, v) -> {
+                if (v != null) {
+                    try {
+                        Class genClass = Class.forName(v.toString());
+                        Object rawObject = genClass.getConstructor().newInstance();
+                        if (!(rawObject instanceof TokenValidator)) {
+                            logger.error("parse token validator error, {}", v);
+                            return;
+                        }
+                        tokenValidatorMap.put(k.toString(), (TokenValidator) rawObject);
+                    } catch (Exception e) {
+                        logger.error("register token validator error {} : {}", k, v, e);
+                    }
+                }
+            });
+        }
+    }
+
+    @Override
+    public void start() {
+        logger.info("register token validator : {}", this.tokenValidatorMap);
+    }
+
+    @Override
+    public void destroy() {
+        this.tokenValidatorMap.clear();
+    }
+
+
+}
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/server/OsxServer.java b/java/osx/osx-broker/src/main/java/com/osx/broker/server/OsxServer.java
new file mode 100644
index 0000000000..6bded87251
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/server/OsxServer.java
@@ -0,0 +1,311 @@
+/*
+ * Copyright 2019 The FATE Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package com.osx.broker.server; + +import com.osx.broker.grpc.ContextPrepareInterceptor; +import com.osx.broker.grpc.PcpGrpcService; +import com.osx.broker.grpc.ProxyGrpcService; +import com.osx.broker.grpc.ServiceExceptionHandler; +import com.osx.broker.http.DispatchServlet; +import com.osx.broker.http.HttpsClientPool; +import com.osx.core.utils.OSXCertUtils; +import com.osx.core.utils.OsxX509TrustManager; +import com.osx.core.config.MetaInfo; +import io.grpc.ServerInterceptors; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; +import org.apache.commons.lang3.StringUtils; +import org.eclipse.jetty.server.HttpConnectionFactory; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.server.SslConnectionFactory; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import java.io.File; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.KeyStore; +import java.security.SecureRandom; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static com.osx.core.config.MetaInfo.PROPERTY_OPEN_GRPC_TLS_SERVER; + +/** + * http1.X + grpc + */ +public class OsxServer { + + Logger logger = LoggerFactory.getLogger(OsxServer.class); + io.grpc.Server server; + io.grpc.Server tlsServer; + org.eclipse.jetty.server.Server httpServer; + org.eclipse.jetty.server.Server httpsServer; + ProxyGrpcService proxyGrpcService; + PcpGrpcService pcpGrpcService; + + private synchronized void init() { + try { + proxyGrpcService = new ProxyGrpcService(); + pcpGrpcService = new PcpGrpcService(); + server = buildServer(); + if (MetaInfo.PROPERTY_OPEN_HTTP_SERVER) { + logger.info("prepare to create http server"); + httpServer = buildHttpServer(); + if (httpServer == null) { + System.exit(0); + } + if (MetaInfo.PROPERTY_HTTP_USE_TLS) { + logger.info("prepare to create http server with TLS"); + httpsServer = buildHttpsServer(); + if (httpsServer == null) { + System.exit(0); + } + } + } + tlsServer = buildTlsServer(); + }catch(Exception e){ + logger.error("server init error ",e); + e.printStackTrace(); + } + } + + public Server buildHttpServer() { + Server server = new Server(); + try { + HttpConnectionFactory http11 = new HttpConnectionFactory(); + ServerConnector connector; + connector = new ServerConnector(server, MetaInfo.PROPERTY_HTTP_SERVER_ACCEPTOR_NUM, MetaInfo.PROPERTY_HTTP_SERVER_SELECTOR_NUM, http11); + // logger.info("http server try to start listen port {}", MetaInfo.PROPERTY_HTTP_PORT); + connector.setPort(MetaInfo.PROPERTY_HTTP_PORT); + connector.setHost(MetaInfo.PROPERTY_BIND_HOST); + connector.setAcceptQueueSize(MetaInfo.PROPERTY_HTTP_RECEIVE_QUEUE_SIZE); + connector.setAcceptedReceiveBufferSize(MetaInfo.PROPERTY_HTTP_ACCEPT_RECEIVE_BUFFER_SIZE); + server.addConnector(connector); + server.setHandler(buildServlet()); + return server; + } catch (Exception e) { + logger.error("build http server error", e); + } + return null; + } 
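For reference, the buildHttpServer wiring above is standard Jetty: one connector bound to a host and port plus a servlet context handler. A minimal runnable sketch of the same shape, with the `MetaInfo` properties and `DispatchServlet` replaced by literals and a placeholder (both assumptions for illustration):

```java
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;

public class HttpServerSketch {
    public static void main(String[] args) throws Exception {
        Server server = new Server();

        // Connector plays the role of the acceptor/selector-tuned connector above.
        ServerConnector connector = new ServerConnector(server);
        connector.setPort(8087);        // stands in for MetaInfo.PROPERTY_HTTP_PORT
        connector.setHost("0.0.0.0");   // stands in for MetaInfo.PROPERTY_BIND_HOST
        server.addConnector(connector);

        // Context handler plays the role of buildServlet(); the real patch
        // registers DispatchServlet under the configured servlet path.
        ServletContextHandler context = new ServletContextHandler();
        context.setContextPath("/osx"); // assumed context path for illustration
        server.setHandler(context);

        server.start();
        server.join();
    }
}
```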
+
+    public Server buildHttpsServer() {
+        Server server = new Server();
+        try {
+            HttpConnectionFactory http11 = new HttpConnectionFactory();
+            ServerConnector connector;
+            SslContextFactory.Server sslServer = new SslContextFactory.Server();
+            // If PROPERTY_HTTP_SSL_TRUST_STORE_PATH is empty, read the certificate suite and build a trust store from it.
+            if (StringUtils.isNotBlank(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH)) {
+                sslServer.setTrustStoreType(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_TYPE.toUpperCase());
+                sslServer.setKeyStorePath(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH);
+                sslServer.setTrustStore(OSXCertUtils.getTrustStore(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_TYPE));
+                if (StringUtils.isAllBlank(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD)) {
+                    throw new IllegalArgumentException("http.ssl.key.store.password/http.ssl.trust.store.password is not set, please check config file");
+                }
+                sslServer.setTrustStorePassword(StringUtils.firstNonBlank(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD));
+                sslServer.setKeyStorePassword(StringUtils.firstNonBlank(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD));
+                sslServer.setTrustStoreProvider(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PROVIDER);
+            } else {
+                SSLContext sslContext = SSLContext.getInstance("SSL");
+                KeyStore keyStore = OSXCertUtils.getKeyStore(MetaInfo.PROPERTY_SERVER_CA_FILE, MetaInfo.PROPERTY_SERVER_CERT_CHAIN_FILE, MetaInfo.PROPERTY_SERVER_PRIVATE_KEY_FILE);
+                TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)};
+                // Load client certificate
+                KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
+                kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray());
+                sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom());
+                sslServer.setSslContext(sslContext);
+            }
+            sslServer.setNeedClientAuth(true);
+            sslServer.setSslSessionTimeout(MetaInfo.PROPERTY_HTTP_SSL_SESSION_TIME_OUT);
+            SslConnectionFactory tls = new SslConnectionFactory(sslServer, http11.getProtocol());
+            connector = new ServerConnector(server, MetaInfo.PROPERTY_HTTP_SERVER_ACCEPTOR_NUM, MetaInfo.PROPERTY_HTTP_SERVER_SELECTOR_NUM, tls, http11);
+            // logger.info("https server try to start listen port {}", MetaInfo.PROPERTY_HTTPS_PORT);
+            connector.setPort(MetaInfo.PROPERTY_HTTPS_PORT);
+            connector.setHost(MetaInfo.PROPERTY_BIND_HOST);
+            connector.setAcceptQueueSize(MetaInfo.PROPERTY_HTTP_RECEIVE_QUEUE_SIZE);
+            connector.setAcceptedReceiveBufferSize(MetaInfo.PROPERTY_HTTP_ACCEPT_RECEIVE_BUFFER_SIZE);
+            server.addConnector(connector);
+            server.setHandler(buildServlet());
+            return server;
+        } catch (Exception e) {
+            logger.error("build https server error", e);
+        }
+        return null;
+    }
+
+    ServletContextHandler buildServlet() {
+        ServletContextHandler context = new ServletContextHandler();
+        context.setContextPath(MetaInfo.PROPERTY_HTTP_CONTEXT_PATH);
+        context.addServlet(DispatchServlet.class, MetaInfo.PROPERTY_HTTP_SERVLET_PATH);
+        context.setMaxFormContentSize(Integer.MAX_VALUE);
+        return context;
+    }
+
+    public boolean start() {
+        init();
+        //grpc
+        try {
+            server.start();
+            logger.info("listen grpc port {} success", MetaInfo.PROPERTY_GRPC_PORT);
+        } catch (Exception e) {
+            if (e instanceof IOException || e.getCause() instanceof java.net.BindException) {
+                logger.error("port {} already in use, please try to choose another one !!!!", 
MetaInfo.PROPERTY_GRPC_PORT); + } + e.printStackTrace(); + return false; + } + + //http + try { + if (httpServer != null) { + httpServer.start(); + logger.info("listen http port {} success", MetaInfo.PROPERTY_HTTP_PORT); + } + } catch (Exception e) { + if (e instanceof java.net.BindException || e.getCause() instanceof java.net.BindException) { + logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_HTTP_PORT); + } + e.printStackTrace(); + return false; + } + + //tls + try { + if (tlsServer != null) { + logger.info("grpc tls server try to start, listen port {}", MetaInfo.PROPERTY_GRPC_TLS_PORT); + tlsServer.start(); + logger.info("listen grpc tls port {} success", MetaInfo.PROPERTY_GRPC_TLS_PORT); + } + } catch (Exception e) { + if (e instanceof java.net.BindException || e.getCause() instanceof java.net.BindException) { + logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_TLS_PORT); + } + e.printStackTrace(); + return false; + } + + //https + try { + if (httpsServer != null) { + httpsServer.start(); + logger.info("listen https port {} success", MetaInfo.PROPERTY_HTTPS_PORT); + } + } catch (Exception e) { + if (e instanceof java.net.BindException || e.getCause() instanceof java.net.BindException) { + logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_HTTPS_PORT); + } + e.printStackTrace(); + return false; + } + return true; + } + + private io.grpc.Server buildTlsServer() { + String certChainFilePath = MetaInfo.PROPERTY_SERVER_CERT_CHAIN_FILE; + String privateKeyFilePath = MetaInfo.PROPERTY_SERVER_PRIVATE_KEY_FILE; + String trustCertCollectionFilePath = MetaInfo.PROPERTY_SERVER_CA_FILE; + if (PROPERTY_OPEN_GRPC_TLS_SERVER && StringUtils.isNotBlank(certChainFilePath) + && StringUtils.isNotBlank(privateKeyFilePath) && StringUtils.isNotBlank(trustCertCollectionFilePath)) { + try { + SocketAddress address = new InetSocketAddress(MetaInfo.PROPERTY_BIND_HOST, MetaInfo.PROPERTY_GRPC_TLS_PORT); + NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address); + SslContextBuilder sslContextBuilder = GrpcSslContexts.forServer(new File(certChainFilePath), new File(privateKeyFilePath)) + .trustManager(new File(trustCertCollectionFilePath)) + .clientAuth(ClientAuth.REQUIRE) + .sessionTimeout(MetaInfo.PROPERTY_GRPC_SSL_SESSION_TIME_OUT) + .sessionCacheSize(MetaInfo.PROPERTY_HTTP_SSL_SESSION_CACHE_SIZE); + logger.info("running in secure mode. 
server crt path: {}, server key path: {}, ca crt path: {}.", + certChainFilePath, privateKeyFilePath, trustCertCollectionFilePath); + //serverBuilder.executor(executor); + nettyServerBuilder.sslContext(GrpcSslContexts.configure(sslContextBuilder, SslProvider.OPENSSL).build()); + nettyServerBuilder.addService(ServerInterceptors.intercept(proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); + nettyServerBuilder.addService(ServerInterceptors.intercept(pcpGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); + + + nettyServerBuilder + .executor(Executors.newCachedThreadPool()) + .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) + .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) + .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) + .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) + nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) + nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) { + nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); + } + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) + nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) + nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) + nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) + nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); + + return nettyServerBuilder.build(); + } catch (SSLException e) { + throw new SecurityException(e); + } + } + return null; + } + + + private io.grpc.Server buildServer() { + SocketAddress address = new InetSocketAddress(MetaInfo.PROPERTY_BIND_HOST, MetaInfo.PROPERTY_GRPC_PORT); + NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address); + nettyServerBuilder.addService(ServerInterceptors.intercept(proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); + nettyServerBuilder.addService(ServerInterceptors.intercept(pcpGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); + nettyServerBuilder + .executor(Executors.newCachedThreadPool()) + .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) + .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) + .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) + .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) + nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) + 
nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) { + nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); + } + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) + nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) + nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) + nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) + nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); + return nettyServerBuilder.build(); + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/service/PushService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/service/PushService.java similarity index 70% rename from java/osx/broker/src/main/java/com/osx/broker/service/PushService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/service/PushService.java index c3a3fb645c..d9643e1955 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/service/PushService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/service/PushService.java @@ -15,9 +15,10 @@ */ package com.osx.broker.service; -import com.osx.broker.grpc.PushRequestDataWrap; +import com.osx.broker.ServiceContainer; import com.osx.broker.grpc.QueuePushReqStreamObserver; -import com.osx.core.context.Context; +import com.osx.core.config.MetaInfo; +import com.osx.core.context.FateContext; import com.osx.core.exceptions.ExceptionInfo; import com.osx.core.exceptions.SysException; import com.osx.core.service.AbstractServiceAdaptor; @@ -27,24 +28,26 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PushService extends AbstractServiceAdaptor { +public class PushService extends AbstractServiceAdaptor { Logger logger = LoggerFactory.getLogger(PushService.class); + + @Override - protected StreamObserver doService(Context context, InboundPackage data + protected StreamObserver doService(FateContext context, InboundPackage data ) { - PushRequestDataWrap pushRequestDataWrap = data.getBody(); - StreamObserver backRespSO = pushRequestDataWrap.getStreamObserver(); - context.setNeedPrintFlowLog(false); + StreamObserver backRespSO = data.getBody(); + // context.setNeedPrintFlowLog(false); QueuePushReqStreamObserver queuePushReqStreamObserver = new QueuePushReqStreamObserver(context, + ServiceContainer.routerRegister.getRouterService(MetaInfo.PROPERTY_FATE_TECH_PROVIDER), backRespSO, Proxy.Metadata.class); return queuePushReqStreamObserver; } @Override - protected StreamObserver transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { + protected StreamObserver transformExceptionInfo(FateContext context, ExceptionInfo exceptionInfo) { logger.error("PushService error {}", exceptionInfo); throw new SysException(exceptionInfo.toString()); } diff --git a/java/osx/broker/src/main/java/com/osx/broker/service/RegisterService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/service/RegisterService.java similarity index 100% rename from 
java/osx/broker/src/main/java/com/osx/broker/service/RegisterService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/service/RegisterService.java diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/service/RouteService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/service/RouteService.java new file mode 100644 index 0000000000..d074bd5db8 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/service/RouteService.java @@ -0,0 +1,28 @@ +package com.osx.broker.service; + +import com.osx.api.context.Context; +import com.osx.broker.router.DefaultFateRouterServiceImpl; +import com.osx.core.exceptions.ExceptionInfo; +import com.osx.core.service.AbstractServiceAdaptor; +import com.osx.core.service.InboundPackage; +import com.webank.ai.eggroll.api.networking.proxy.Proxy; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RouteService extends AbstractServiceAdaptor { + + Logger logger = LoggerFactory.getLogger(RouteService.class); + + @Override + protected Proxy.Packet doService(Context context, InboundPackage data) { + DefaultFateRouterServiceImpl defaultFateRouterService = new DefaultFateRouterServiceImpl(); + defaultFateRouterService.saveRouterTable(context, data); + Proxy.Packet.Builder resultBuilder = Proxy.Packet.newBuilder(); + return resultBuilder.build(); + } + + @Override + protected Proxy.Packet transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { + return null; + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/service/TokenApplyService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/service/TokenApplyService.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/service/TokenApplyService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/service/TokenApplyService.java index be693393df..6cda19a898 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/service/TokenApplyService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/service/TokenApplyService.java @@ -17,16 +17,17 @@ import com.google.protobuf.ByteString; +import com.osx.api.context.Context; +import com.osx.api.router.RouterInfo; import com.osx.broker.ServiceContainer; import com.osx.core.config.MetaInfo; import com.osx.core.constant.Dict; import com.osx.core.constant.StreamLimitMode; -import com.osx.core.context.Context; + import com.osx.core.flow.FlowRule; import com.osx.core.frame.GrpcConnectionFactory; import com.osx.core.frame.Lifecycle; import com.osx.core.ptp.TargetMethod; -import com.osx.core.router.RouterInfo; import com.osx.core.token.TokenRequest; import com.osx.core.token.TokenResult; import com.osx.core.token.TokenResultStatus; @@ -57,7 +58,7 @@ public TokenApplyService() { public PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub buildBlockingStub(String address) { String[] ipports= address.split(":"); - RouterInfo routerInfo = new RouterInfo(); + RouterInfo routerInfo = new RouterInfo(); routerInfo.setHost(ipports[0]); routerInfo.setPort(Integer.parseInt(ipports[1])); ManagedChannel channel = GrpcConnectionFactory.createManagedChannel(routerInfo,true); @@ -71,7 +72,7 @@ public void applyToken(Context context, String resource, int count) { if (MetaInfo.PROPERTY_STREAM_LIMIT_MODE.equals(StreamLimitMode.LOCAL.name()) || MetaInfo.PROPERTY_STREAM_LIMIT_MODE.equals(StreamLimitMode.CLUSTER.name())) { TokenResult localTokenResult = tryLocalLimit(resource, count); - logger.info("request token {} count {} result {}", resource, count, 
localTokenResult);
+            // logger.info("request token {} count {} result {}", resource, count, localTokenResult);
             /*
              * Cluster-level rate limiting.
              */
@@ -111,8 +112,8 @@ private TokenResult tryLocalLimit(String resource, int count) {
                 logger.info("should wait {} ms", sleepMs);
                 try {
                     Thread.sleep(sleepMs);
-                } catch (InterruptedException e) {
-                    e.printStackTrace();
+                } catch (InterruptedException ignore) {
+
                 }
                 needLoop = false;
                 break;
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/service/UnaryCallService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/service/UnaryCallService.java
new file mode 100644
index 0000000000..5488fe6890
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/service/UnaryCallService.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2019 The FATE Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.osx.broker.service;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.osx.api.router.RouterInfo;
+import com.osx.broker.util.TransferUtil;
+import com.osx.core.config.MetaInfo;
+import com.osx.core.constant.ActionType;
+import com.osx.api.constants.Protocol;
+import com.osx.core.constant.StatusCode;
+import com.osx.core.context.FateContext;
+import com.osx.core.exceptions.ExceptionInfo;
+import com.osx.core.exceptions.NoRouterInfoException;
+import com.osx.core.exceptions.RemoteRpcException;
+import com.osx.core.frame.GrpcConnectionFactory;
+import com.osx.core.ptp.SourceMethod;
+import com.osx.core.ptp.TargetMethod;
+import com.osx.core.service.AbstractServiceAdaptor;
+import com.osx.core.service.InboundPackage;
+import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc;
+import com.webank.ai.eggroll.api.networking.proxy.Proxy;
+import io.grpc.ManagedChannel;
+import org.ppc.ptp.Osx;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Kept for compatibility with legacy FATE.
+ */
+public class UnaryCallService extends AbstractServiceAdaptor {
+
+    Logger logger = LoggerFactory.getLogger(UnaryCallService.class);
+
+    public UnaryCallService() {
+
+    }
+
+    @Override
+    protected Proxy.Packet doService(FateContext context, InboundPackage data) {
+        context.setActionType(ActionType.UNARY_CALL.getAlias());
+        Proxy.Packet req = (Proxy.Packet) data.getBody();
+        Proxy.Packet resp = unaryCall(context, req);
+        //logger.info("unary req {} resp {}", req, resp);
+        return resp;
+    }
+
+    @Override
+    protected Proxy.Packet transformExceptionInfo(FateContext context, ExceptionInfo exceptionInfo) {
+        throw new RemoteRpcException(exceptionInfo.toString());
+    }
+
+    /**
+     * Non-streaming (unary) transfer.
+     *
+     * @param context
+     * @param req
+     */
+    public Proxy.Packet unaryCall(FateContext context, Proxy.Packet req) {
+        Proxy.Packet result = null;
+        RouterInfo routerInfo = context.getRouterInfo();
+        if (routerInfo == null) {
+            String sourcePartyId = context.getSrcPartyId();
+            String desPartyId = context.getDesPartyId();
+            throw new NoRouterInfoException(sourcePartyId + " to " + desPartyId + " found no router info");
+        }
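The branch that follows dispatches on the resolved route's protocol: HTTP routes get the packet wrapped into the PCP Inbound envelope and redirected, everything else goes over a blocking gRPC stub. A hedged sketch of that dispatch with simplified stand-in types (not the project's transfer classes):

```java
public class ProtocolDispatchSketch {

    interface Transport {
        byte[] send(byte[] payload) throws Exception;
    }

    // The patch only special-cases "http"; grpc is the default path,
    // mirroring how unaryCall branches on routerInfo.getProtocol().
    static Transport select(String protocol, Transport httpRedirect, Transport grpcStub) {
        return "http".equals(protocol) ? httpRedirect : grpcStub;
    }
}
```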
if(routerInfo.getProtocol().equals(Protocol.http)){ + Osx.Inbound inbound = TransferUtil.buildInboundFromPushingPacket(req, MetaInfo.PROPERTY_FATE_TECH_PROVIDER, TargetMethod.UNARY_CALL.name(), SourceMethod.OLDUNARY_CALL.name()).build(); + Osx.Outbound outbound = TransferUtil.redirect(context,inbound,routerInfo,true); + if(outbound!=null) { + if (outbound.getCode().equals(StatusCode.SUCCESS)) { + try { + result = Proxy.Packet.parseFrom(outbound.getPayload().toByteArray()); + } catch (InvalidProtocolBufferException e) { + e.printStackTrace(); + } + } else { + throw new RemoteRpcException(outbound.getMessage()); + } + } + }else { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(), true); + DataTransferServiceGrpc.DataTransferServiceBlockingStub stub = DataTransferServiceGrpc.newBlockingStub(managedChannel); + result = stub.unaryCall(req); + } + return result; + } + + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/store/IndexQueue.java b/java/osx/osx-broker/src/main/java/com/osx/broker/store/IndexQueue.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/store/IndexQueue.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/store/IndexQueue.java index 23bb9e4ec7..bf2fca6ac7 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/store/IndexQueue.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/store/IndexQueue.java @@ -39,7 +39,7 @@ public class IndexQueue { private long maxPhysicOffset = -1; private volatile long minLogicOffset = 0; private AtomicLong logicOffset = new AtomicLong(0); - + Logger logger = LoggerFactory.getLogger(IndexQueue.class); public IndexQueue( final String transferId, final String storePath, @@ -68,9 +68,7 @@ public boolean load() { public long getLastOffset() { long lastOffset = -1; - int logicFileSize = this.mappedFileSize; - MappedFile mappedFile = this.mappedFileQueue.getLastMappedFile(); if (mappedFile != null) { @@ -143,18 +141,18 @@ public long getMinOffsetInQueue() { public long putMessagePositionInfoWrapper(long offset, int msgSize) { final int maxRetries = 30; - + long resultLogicOffset = -1; for (int i = 0; i < maxRetries; i++) { boolean result = this.putMessagePositionInfo(offset, msgSize, this.logicOffset.get() + 1); if (result) { - return logicOffset.addAndGet(1); - + resultLogicOffset = logicOffset.addAndGet(1); + return resultLogicOffset; } } - return -1; + return resultLogicOffset; } @@ -181,8 +179,8 @@ private boolean putMessagePositionInfo(final long offset, final int size, this.mappedFileQueue.setFlushedWhere(expectLogicOffset); this.mappedFileQueue.setCommittedWhere(expectLogicOffset); this.fillPreBlank(mappedFile, expectLogicOffset); - log.info("fill pre blank space " + mappedFile.getFileName() + " " + expectLogicOffset + " " - + mappedFile.getWrotePosition()); +// log.info("fill pre blank space " + mappedFile.getFileName() + " " + expectLogicOffset + " " +// + mappedFile.getWrotePosition()); } if (cqOffset != 0) { diff --git a/java/osx/broker/src/main/java/com/osx/broker/store/MessageStore.java b/java/osx/osx-broker/src/main/java/com/osx/broker/store/MessageStore.java similarity index 100% rename from java/osx/broker/src/main/java/com/osx/broker/store/MessageStore.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/store/MessageStore.java diff --git a/java/osx/broker/src/main/java/com/osx/broker/token/DefaultTokenService.java b/java/osx/osx-broker/src/main/java/com/osx/broker/token/DefaultTokenService.java similarity 
index 94% rename from java/osx/broker/src/main/java/com/osx/broker/token/DefaultTokenService.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/token/DefaultTokenService.java index 1b54594650..49aa299129 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/token/DefaultTokenService.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/token/DefaultTokenService.java @@ -16,7 +16,6 @@ package com.osx.broker.token; import com.osx.core.config.MetaInfo; import com.osx.core.constant.Dict; -import com.osx.core.context.Context; import com.osx.core.exceptions.ExceptionInfo; import com.osx.core.flow.*; import com.osx.core.service.AbstractServiceAdaptor; @@ -46,10 +45,10 @@ public TokenResult requestToken(String resource, int acquireCount, boolean prior } FlowRule rule = ClusterFlowRuleManager.getFlowRuleByResource(resource); if (rule == null) { - logger.error("resource {} no rule", resource); + //logger.error("resource {} no rule", resource); ClusterMetric clusterMetric = ClusterMetricStatistics.getMetric(resource); if (clusterMetric == null) { - ClusterMetricStatistics.putMetricIfAbsent(resource, new ClusterMetric(MetaInfo.PROPERTY_SAMPLE_COUNT, MetaInfo.PROPERTY_INTERVAL_MS)); + ClusterMetricStatistics.putMetricIfAbsent(resource, new ClusterMetric(MetaInfo.PROPERTY_FLOW_CONTROL_SAMPLE_COUNT, MetaInfo.PROPERTY_FLOW_CONTROL_SAMPLE_INTERVAL)); clusterMetric = ClusterMetricStatistics.getMetric(resource); } clusterMetric.add(ClusterFlowEvent.PASS, acquireCount); diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/ContextUtil.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/ContextUtil.java similarity index 79% rename from java/osx/broker/src/main/java/com/osx/broker/util/ContextUtil.java rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/ContextUtil.java index fff76b41b5..cd1d9fa145 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/ContextUtil.java +++ b/java/osx/osx-broker/src/main/java/com/osx/broker/util/ContextUtil.java @@ -16,16 +16,18 @@ package com.osx.broker.util; import com.osx.broker.grpc.ContextPrepareInterceptor; -import com.osx.core.context.Context; - -import java.util.UUID; +import com.osx.api.constants.Protocol; +import com.osx.core.context.FateContext; public class ContextUtil { - public static Context buildContext() { - Context context = new Context(); + public static FateContext buildFateContext(Protocol protocol) { + FateContext context = new FateContext(); + context.setProtocol(protocol); context.setSourceIp(ContextPrepareInterceptor.sourceIp.get() != null ? 
ContextPrepareInterceptor.sourceIp.get().toString() : "");
-        context.setCaseId(UUID.randomUUID().toString());
+        return context;
     }
+
+
 }
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/DateUtils.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/DateUtils.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/util/DateUtils.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/DateUtils.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/LibC.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/LibC.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/util/LibC.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/LibC.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/MessageConst.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/MessageConst.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/util/MessageConst.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/MessageConst.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/MessageId.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/MessageId.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/util/MessageId.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/MessageId.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/ResourceUtil.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/ResourceUtil.java
similarity index 96%
rename from java/osx/broker/src/main/java/com/osx/broker/util/ResourceUtil.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/ResourceUtil.java
index a55489f0ac..8272ec8fb8 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/util/ResourceUtil.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/util/ResourceUtil.java
@@ -16,8 +16,9 @@
 package com.osx.broker.util;
+import com.osx.api.router.RouterInfo;
 import com.osx.broker.constants.Direction;
-import com.osx.core.router.RouterInfo;
+
 public class ResourceUtil {
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/util/TelnetUtil.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/TelnetUtil.java
new file mode 100644
index 0000000000..f9d0b784eb
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/util/TelnetUtil.java
@@ -0,0 +1,21 @@
+package com.osx.broker.util;
+
+import org.apache.commons.net.telnet.TelnetClient;
+
+public class TelnetUtil {
+
+    public static boolean tryTelnet(String host, int port) {
+        TelnetClient telnetClient = new TelnetClient("vt200");
+        telnetClient.setDefaultTimeout(5000);
+        boolean isConnected = false;
+        try {
+            telnetClient.connect(host, port);
+            isConnected = true;
+            telnetClient.disconnect();
+        } catch (Exception e) {
+            //e.printStackTrace();
+        }
+        return isConnected;
+    }
+
+}
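A short usage sketch for this helper (not part of the patch; host and port are illustrative values). tryTelnet swallows connection errors and simply reports reachability:

    import com.osx.broker.util.TelnetUtil;

    public class TelnetCheckSketch {
        public static void main(String[] args) {
            // true only when a TCP connection succeeds within the 5s default timeout
            boolean reachable = TelnetUtil.tryTelnet("127.0.0.1", 9370);
            System.out.println("broker reachable: " + reachable);
        }
    }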
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/TimeUtils.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/TimeUtils.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/util/TimeUtils.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/TimeUtils.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/TransferExceptionUtil.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/TransferExceptionUtil.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/util/TransferExceptionUtil.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/TransferExceptionUtil.java
diff --git a/java/osx/osx-broker/src/main/java/com/osx/broker/util/TransferUtil.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/TransferUtil.java
new file mode 100644
index 0000000000..8c0ca5a5e2
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/util/TransferUtil.java
@@ -0,0 +1,550 @@
+/*
+ * Copyright 2019 The FATE Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.osx.broker.util;
+
+
+import com.google.common.collect.Maps;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.osx.api.constants.Protocol;
+import com.osx.api.context.Context;
+import com.osx.api.router.RouterInfo;
+import com.osx.broker.constants.MessageFlag;
+import com.osx.broker.http.HttpClientPool;
+import com.osx.broker.http.HttpsClientPool;
+import com.osx.broker.queue.TransferQueue;
+import com.osx.core.config.MetaInfo;
+import com.osx.core.config.TransferMeta;
+import com.osx.core.constant.Dict;
+import com.osx.core.constant.PtpHttpHeader;
+import com.osx.core.constant.Role;
+import com.osx.core.constant.StatusCode;
+import com.osx.core.context.FateContext;
+import com.osx.core.exceptions.*;
+import com.osx.core.frame.GrpcConnectionFactory;
+import com.osx.core.ptp.SourceMethod;
+import com.osx.core.utils.AssertUtil;
+import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc;
+import com.webank.ai.eggroll.api.networking.proxy.Proxy;
+import com.webank.eggroll.core.transfer.Transfer;
+import io.grpc.ManagedChannel;
+import io.grpc.StatusRuntimeException;
+import org.apache.commons.lang3.StringUtils;
+import org.ppc.ptp.Osx;
+import org.ppc.ptp.PrivateTransferProtocolGrpc;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Map;
+
+public class TransferUtil {
+
+
+    static Logger logger = LoggerFactory.getLogger(TransferUtil.class);
+
+
+    /**
+     * Whether the remote peer runs a pre-2.0 FATE; empty versions fall back
+     * to PROPERTY_DEFAULT_CLIENT_VERSION.
+     *
+     * @param version
+     * @return
+     */
+    public static boolean isOldVersionFate(String version) {
+        try {
+            if (StringUtils.isEmpty(version))
+                version = MetaInfo.PROPERTY_DEFAULT_CLIENT_VERSION;
+            // only the major version matters, e.g. "1.11.2" -> old, "2.0.0" -> new
+            String majorVersion = version.substring(0, 1);
+            return Integer.parseInt(majorVersion) < 2;
+        } catch (NumberFormatException e) {
+            throw new ConfigErrorException("remote version config error : " + version);
+        }
+    }
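+    /**
+     * Builds the flow-control resource key as
+     * {sourceInstId}{sourceNodeId}_{targetInstId}{targetNodeId}; missing inst
+     * ids collapse to the empty string.
+     */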
+    public static String buildResource(Osx.Inbound inbound) {
+        String sourceNodeId = inbound.getMetadataMap().get(Osx.Header.SourceNodeID.name());
+        String targetNodeId = inbound.getMetadataMap().get(Osx.Header.TargetNodeID.name());
+        String sourceInstId = inbound.getMetadataMap().get(Osx.Header.SourceInstID.name());
+        if (sourceInstId == null) {
+            sourceInstId = "";
+        }
+        String targetInstId = inbound.getMetadataMap().get(Osx.Header.TargetInstID.name());
+        if (targetInstId == null) {
+            targetInstId = "";
+        }
+        StringBuilder sb = new StringBuilder();
+        sb.append(sourceInstId).append(sourceNodeId).append("_").append(targetInstId).append(targetNodeId);
+        return sb.toString();
+    }
+
+    public static Proxy.Metadata buildProxyMetadataFromOutbound(Osx.Outbound outbound) {
+        try {
+            return Proxy.Metadata.parseFrom(outbound.getPayload());
+        } catch (InvalidProtocolBufferException e) {
+            logger.error("parse proxy metadata from outbound error", e);
+        }
+        return null;
+    }
+
+    public static Osx.Outbound buildOutboundFromProxyMetadata(Proxy.Metadata metadata) {
+        return Osx.Outbound.newBuilder().setPayload(metadata.toByteString()).build();
+    }
+
+    public static Proxy.Packet parsePacketFromInbound(Osx.Inbound inbound) {
+        try {
+            return Proxy.Packet.parseFrom(inbound.getPayload());
+        } catch (InvalidProtocolBufferException e) {
+            return null;
+        }
+    }
+
+    public static Osx.Inbound.Builder buildInbound(String provider,
+                                                   String srcPartyId,
+                                                   String desPartyId,
+                                                   String targetMethod,
+                                                   String topic,
+                                                   MessageFlag messageFlag,
+                                                   String sessionId,
+                                                   byte[] payLoad) {
+
+        Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder();
+        inboundBuilder.putMetadata(Osx.Header.Version.name(), MetaInfo.CURRENT_VERSION);
+        inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), provider);
+//        inboundBuilder.putMetadata(Osx.Header.Token.name(), "");
+        inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), srcPartyId);
+        inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), desPartyId);
+//        inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), "");
+//        inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), "");
+        if (StringUtils.isNotEmpty(sessionId)) {
+            inboundBuilder.putMetadata(Osx.Header.SessionID.name(), sessionId);
+        }
+        inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), targetMethod);
+//        inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), "");
+//        inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), "");
+        if (StringUtils.isNotEmpty(topic)) {
+            inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), topic);
+        }
+        if (messageFlag != null) {
+            inboundBuilder.putMetadata(Osx.Metadata.MessageFlag.name(), messageFlag.name());
+        }
+        if (payLoad != null) {
+            inboundBuilder.setPayload(ByteString.copyFrom(payLoad));
+        }
+        return inboundBuilder;
+    }
+
+
+    public static TransferMeta parseTransferMetaFromProxyPacket(Proxy.Packet packet) {
+        TransferMeta transferMeta = new TransferMeta();
+        Proxy.Metadata metadata = packet.getHeader();
+        Transfer.RollSiteHeader rollSiteHeader = null;
+        String dstPartyId = null;
+        String srcPartyId = null;
+        String desRole = null;
+        String srcRole = null;
+        try {
+            rollSiteHeader = Transfer.RollSiteHeader.parseFrom(metadata.getExt());
+        } catch (InvalidProtocolBufferException e) {
+            throw new ParameterException("invalid rollSiteHeader");
+        }
+        String sessionId = "";
+        if (rollSiteHeader != null) {
+            dstPartyId = rollSiteHeader.getDstPartyId();
+            srcPartyId = rollSiteHeader.getSrcPartyId();
+            desRole = rollSiteHeader.getDstRole();
+            srcRole = rollSiteHeader.getSrcRole();
+        }
+        // prefer ids from the RollSiteHeader extension, fall back to the packet header
+        if (StringUtils.isEmpty(dstPartyId)) {
+            dstPartyId = metadata.getDst().getPartyId();
+        }
+        if (StringUtils.isEmpty(desRole)) {
+            desRole = metadata.getDst().getRole();
+        }
+        if (StringUtils.isEmpty(srcRole)) {
+            srcRole = metadata.getSrc().getRole();
+        }
+        if (StringUtils.isEmpty(srcPartyId)) {
+            srcPartyId = metadata.getSrc().getPartyId();
+        }
+
+        if (rollSiteHeader != null) {
+            sessionId = String.join("_", rollSiteHeader.getRollSiteSessionId(), desRole, dstPartyId);
+        }
+        if (metadata.getDst() != null) {
+            transferMeta.setTopic(metadata.getDst().getName());
+        }
+
+        transferMeta.setDesPartyId(dstPartyId);
+        transferMeta.setSrcPartyId(srcPartyId);
+        transferMeta.setDesRole(desRole);
+        transferMeta.setSrcRole(srcRole);
+        transferMeta.setSessionId(sessionId);
+        return transferMeta;
+    }
+
+    public static void assableContextFromInbound(Context context, Osx.Inbound request) {
+        Map<String, String> metaDataMap = request.getMetadataMap();
+        String version = metaDataMap.get(Osx.Header.Version.name());
+        String jobId = metaDataMap.get(Osx.Metadata.JobId.name());
+        String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name());
+        String traceId = metaDataMap.get(Osx.Header.TraceID.name());
+        String token = metaDataMap.get(Osx.Header.Token.name());
+        String sourceNodeId = metaDataMap.get(Osx.Header.SourceNodeID.name());
+        String targetNodeId = metaDataMap.get(Osx.Header.TargetNodeID.name());
+        String sourceInstId = metaDataMap.get(Osx.Header.SourceInstID.name());
+        String targetInstId = metaDataMap.get(Osx.Header.TargetInstID.name());
+        String sessionId = metaDataMap.get(Osx.Header.SessionID.name());
+        String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name());
+        String targetComponentName = metaDataMap.get(Osx.Metadata.TargetComponentName.name());
+        String sourceComponentName = metaDataMap.get(Osx.Metadata.SourceComponentName.name());
+        String sourcePartyId = StringUtils.isEmpty(sourceInstId) ? sourceNodeId : sourceInstId + "." + sourceNodeId;
+        String targetPartyId = StringUtils.isEmpty(targetInstId) ? targetNodeId : targetInstId + "." + targetNodeId;
+        String topic = metaDataMap.get(Osx.Metadata.MessageTopic.name());
+        String offsetString = metaDataMap.get(Osx.Metadata.MessageOffSet.name());
+        String messageCode = metaDataMap.get(Osx.Metadata.MessageCode.name());
+        Long offset = StringUtils.isNotEmpty(offsetString) ?
Long.parseLong(offsetString) : null;
+        context.setTraceId(traceId);
+        context.setToken(token);
+        context.setDesPartyId(targetPartyId);
+        context.setSrcPartyId(sourcePartyId);
+        context.setTopic(topic);
+        context.setJobId(jobId);
+
+        if (context instanceof FateContext) {
+            ((FateContext) context).setRequestMsgIndex(offset);
+            ((FateContext) context).setMessageCode(messageCode);
+        }
+        context.setSessionId(sessionId);
+        context.setDesComponent(targetComponentName);
+        context.setSrcComponent(sourceComponentName);
+        context.setTechProviderCode(techProviderCode);
+        if (MetaInfo.PROPERTY_SELF_PARTY.contains(context.getDesPartyId())) {
+            context.setSelfPartyId(context.getDesPartyId());
+        } else {
+            context.setSelfPartyId(MetaInfo.PROPERTY_SELF_PARTY.toArray()[0].toString());
+        }
+    }
+
+    public static void assableContextFromProxyPacket(Context context, Proxy.Packet packet) {
+        TransferMeta transferMeta = parseTransferMetaFromProxyPacket(packet);
+        context.setSrcPartyId(transferMeta.getSrcPartyId());
+        context.setDesPartyId(transferMeta.getDesPartyId());
+        context.setSrcComponent(transferMeta.getSrcRole());
+        context.setDesComponent(transferMeta.getDesRole());
+        context.setSessionId(transferMeta.getSessionId());
+        context.setTopic(transferMeta.getTopic());
+        context.setTechProviderCode(MetaInfo.PROPERTY_FATE_TECH_PROVIDER);
+        if (MetaInfo.PROPERTY_SELF_PARTY.contains(context.getDesPartyId())) {
+            context.setSelfPartyId(context.getDesPartyId());
+        } else {
+            context.setSelfPartyId(MetaInfo.PROPERTY_SELF_PARTY.toArray()[0].toString());
+        }
+    }
+
+
+    public static Osx.Inbound.Builder buildInboundFromPushingPacket(Proxy.Packet packet, String provider, String targetMethod, String sourceMethod) {
+        Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder();
+        TransferMeta transferMeta = parseTransferMetaFromProxyPacket(packet);
+        inboundBuilder.setPayload(packet.toByteString());
+        inboundBuilder.putMetadata(Osx.Header.Version.name(), MetaInfo.CURRENT_VERSION);
+        inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), provider);
+        inboundBuilder.putMetadata(Osx.Header.Token.name(), "");
+        inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), transferMeta.getSrcPartyId());
+        inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), transferMeta.getDesPartyId());
+        inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), "");
+        inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), "");
+        inboundBuilder.putMetadata(Osx.Metadata.SourceMethod.name(), sourceMethod);
+        inboundBuilder.putMetadata(Osx.Header.SessionID.name(), transferMeta.getSessionId());
+        inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), targetMethod);
+        inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), transferMeta.getDesRole());
+        inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), "");
+        return inboundBuilder;
+    }
+
+
+    public static Osx.Inbound.Builder buildPbFromHttpRequest(Context context, HttpServletRequest request) {
+
+        Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder();
+        String version = request.getHeader(PtpHttpHeader.Version);
+        String techProviderCode = request.getHeader(PtpHttpHeader.TechProviderCode);
+        String traceID = request.getHeader(PtpHttpHeader.TraceID);
+        String token = request.getHeader(PtpHttpHeader.Token);
+        String sourceNodeID = request.getHeader(PtpHttpHeader.SourceNodeID);
+        String targetNodeID = request.getHeader(PtpHttpHeader.TargetNodeID);
+        String sourceInstID = request.getHeader(PtpHttpHeader.SourceInstID);
+        String
targetInstID = request.getHeader(PtpHttpHeader.TargetInstID); + String sessionID = request.getHeader(PtpHttpHeader.SessionID); + String messageTopic = request.getHeader(PtpHttpHeader.MessageTopic); + String messageCode = request.getHeader(Osx.Metadata.MessageCode.name()); + String retryCount = request.getHeader(Osx.Metadata.RetryCount.name()); + String sourceComponentName = request.getHeader(PtpHttpHeader.SourceComponentName); + String targetComponentName = request.getHeader(PtpHttpHeader.TargetComponentName); + String targetMethod = request.getHeader(PtpHttpHeader.TargetMethod); + String sourceMethod = request.getHeader(PtpHttpHeader.SourceMethod); + String messageOffSet = request.getHeader(PtpHttpHeader.MessageOffSet); + String instanceId = request.getHeader(PtpHttpHeader.InstanceId); + String timestamp = request.getHeader(PtpHttpHeader.Timestamp); + String messageFlag = request.getHeader(PtpHttpHeader.MessageFlag); + String jobId = request.getHeader(PtpHttpHeader.JobId); + context.setSrcPartyId(sourceNodeID); + context.setDesPartyId(targetNodeID); + context.setSessionId(sessionID); + context.setTopic(messageTopic); + context.setActionType(targetMethod); + context.setProtocol(Protocol.http); + inboundBuilder.putMetadata(Osx.Header.Version.name(), version != null ? version : ""); + inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), techProviderCode != null ? techProviderCode : ""); + inboundBuilder.putMetadata(Osx.Header.Token.name(), token != null ? token : ""); + inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), sourceNodeID != null ? sourceNodeID : ""); + inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), targetNodeID != null ? targetNodeID : ""); + inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), sourceInstID != null ? sourceInstID : ""); + inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), targetInstID != null ? targetInstID : ""); + inboundBuilder.putMetadata(Osx.Header.SessionID.name(), sessionID != null ? sessionID : ""); + inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), targetMethod != null ? targetMethod : ""); + inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), targetComponentName != null ? targetComponentName : ""); + inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), sourceComponentName != null ? sourceComponentName : ""); + inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), messageTopic != null ? messageTopic : ""); + inboundBuilder.putMetadata(Osx.Metadata.MessageOffSet.name(), messageOffSet != null ? messageOffSet : ""); + inboundBuilder.putMetadata(Osx.Metadata.InstanceId.name(), instanceId != null ? instanceId : ""); + inboundBuilder.putMetadata(Osx.Metadata.Timestamp.name(), timestamp != null ? timestamp : ""); + inboundBuilder.putMetadata(Osx.Metadata.SourceMethod.name(), sourceMethod != null ? sourceMethod : ""); + inboundBuilder.putMetadata(Osx.Metadata.MessageFlag.name(), messageFlag != null ? messageFlag : ""); + inboundBuilder.putMetadata(Osx.Metadata.JobId.name(), jobId != null ? jobId : ""); + inboundBuilder.putMetadata(Osx.Metadata.MessageCode.name(), messageCode != null ? messageCode : ""); + inboundBuilder.putMetadata(Osx.Metadata.RetryCount.name(), retryCount != null ? 
retryCount : ""); + return inboundBuilder; + } + + + static public Map parseHttpHeader(Osx.Inbound produceRequest) { + Map metaDataMap = produceRequest.getMetadataMap(); + String version = metaDataMap.get(Osx.Header.Version.name()); + String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name()); + String traceId = metaDataMap.get(Osx.Header.TraceID.name()); + String token = metaDataMap.get(Osx.Header.Token.name()); + String sourceNodeId = metaDataMap.get(Osx.Header.SourceNodeID.name()); + String targetNodeId = metaDataMap.get(Osx.Header.TargetNodeID.name()); + String sourceInstId = metaDataMap.get(Osx.Header.SourceInstID.name()); + String targetInstId = metaDataMap.get(Osx.Header.TargetInstID.name()); + String sessionId = metaDataMap.get(Osx.Header.SessionID.name()); + String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); + String sourceMethod = metaDataMap.get(Osx.Metadata.SourceMethod.name()); + String targetComponentName = metaDataMap.get(Osx.Metadata.TargetComponentName.name()); + String sourceComponentName = metaDataMap.get(Osx.Metadata.SourceComponentName.name()); + String sourcePartyId = StringUtils.isEmpty(sourceInstId) ? sourceNodeId : sourceInstId + "." + sourceNodeId; + String targetPartyId = StringUtils.isEmpty(targetInstId) ? targetNodeId : targetInstId + "." + targetNodeId; + String topic = metaDataMap.get(Osx.Metadata.MessageTopic.name()); + String offsetString = metaDataMap.get(Osx.Metadata.MessageOffSet.name()); + String InstanceId = metaDataMap.get(Osx.Metadata.InstanceId.name()); + String timestamp = metaDataMap.get(Osx.Metadata.Timestamp.name()); + String messageCode = metaDataMap.get(Osx.Metadata.MessageCode.name()); + String messageFlag = metaDataMap.get(Osx.Metadata.MessageFlag.name()); + String jobId = metaDataMap.get(Osx.Metadata.JobId.name()); + + Map header = Maps.newHashMap(); + header.put(PtpHttpHeader.Version, version != null ? version : ""); + header.put(PtpHttpHeader.TechProviderCode, techProviderCode != null ? techProviderCode : ""); + header.put(PtpHttpHeader.TraceID, traceId != null ? traceId : ""); + header.put(PtpHttpHeader.Token, token != null ? token : ""); + header.put(PtpHttpHeader.SourceNodeID, sourceNodeId != null ? sourceNodeId : ""); + header.put(PtpHttpHeader.TargetNodeID, targetNodeId != null ? targetNodeId : ""); + header.put(PtpHttpHeader.SourceInstID, sourceInstId != null ? sourceInstId : ""); + header.put(PtpHttpHeader.TargetInstID, targetInstId != null ? targetInstId : ""); + header.put(PtpHttpHeader.SessionID, sessionId != null ? sessionId : ""); + header.put(PtpHttpHeader.MessageTopic, topic != null ? topic : ""); + header.put(PtpHttpHeader.MessageCode, messageCode); + header.put(PtpHttpHeader.SourceComponentName, sourceComponentName != null ? sourceComponentName : ""); + header.put(PtpHttpHeader.TargetComponentName, targetComponentName != null ? targetComponentName : ""); + header.put(PtpHttpHeader.TargetMethod, targetMethod != null ? targetMethod : ""); + header.put(PtpHttpHeader.SourceMethod, sourceMethod != null ? sourceMethod : ""); + header.put(PtpHttpHeader.MessageOffSet, offsetString != null ? offsetString : ""); + header.put(PtpHttpHeader.InstanceId, InstanceId != null ? InstanceId : ""); + header.put(PtpHttpHeader.Timestamp, timestamp != null ? timestamp : ""); + header.put(PtpHttpHeader.MessageFlag, messageFlag != null ? messageFlag : ""); + header.put(PtpHttpHeader.JobId, jobId != null ? 
jobId : ""); + + return header; + } + + static public Osx.Outbound redirect(FateContext context, Osx.Inbound + produceRequest, RouterInfo routerInfo,boolean usePooled) { + AssertUtil.notNull(routerInfo, context.getDesPartyId()!=null?"des partyId "+context.getDesPartyId()+" router info is null":" error router info"); + Osx.Outbound result = null; + context.setDataSize(produceRequest.getSerializedSize()); + if (routerInfo.isCycle()) { + throw new CycleRouteInfoException("cycle router info"); + } + if (routerInfo.getProtocol() == null || routerInfo.getProtocol().equals(Protocol.grpc)) { + //来自旧版fateflow的请求,需要用旧版的stub + if (context.isDestination() && Role.fateflow.name().equals(routerInfo.getDesRole()) + && SourceMethod.OLDUNARY_CALL.name().equals(produceRequest.getMetadataMap().get(Osx.Metadata.SourceMethod.name()))) { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, usePooled); + DataTransferServiceGrpc.DataTransferServiceBlockingStub stub = DataTransferServiceGrpc.newBlockingStub(managedChannel); + Proxy.Packet request; + try { + request = Proxy.Packet.parseFrom(produceRequest.getPayload().toByteArray()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + Proxy.Packet response = stub.unaryCall(request); + result = Osx.Outbound.newBuilder().setPayload(response.toByteString()).setCode(StatusCode.SUCCESS).build(); + } else { + + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = null; + if (context.getData(Dict.BLOCKING_STUB) == null) { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, usePooled); + stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); + } else { + stub = (PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub) context.getData(Dict.BLOCKING_STUB); + } + try { + result = stub.invoke(produceRequest); + } catch (StatusRuntimeException e) { + logger.error("redirect error", e); + throw new RemoteRpcException(StatusCode.NET_ERROR, "send to " + routerInfo.toKey() + " error : " + e.getMessage()); + } + } + // ServiceContainer.tokenApplyService.applyToken(context,routerInfo.getResource(),produceRequest.getSerializedSize()); + } else { + String url = routerInfo.getUrl(); + Map header = parseHttpHeader(produceRequest); + try { + if (routerInfo.getProtocol().equals(Protocol.http)) { + if (routerInfo.isUseSSL()) { + result = HttpsClientPool.sendPtpPost(url, produceRequest.getPayload().toByteArray(), header, routerInfo.getCaFile(), routerInfo.getCertChainFile(), routerInfo.getPrivateKeyFile()); + } else { + + result = HttpClientPool.sendPtpPost(url, produceRequest.getPayload().toByteArray(), header); + } + } + } catch (Exception e) { + e.printStackTrace(); + logger.error("sendPtpPost failed : ", e); + ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); + result = Osx.Outbound.newBuilder().setCode(exceptionInfo.getCode()).setMessage(exceptionInfo.getMessage()).build(); + } + } + return result; + } + + + public static Osx.Outbound.Builder buildResponseInner(String code, String msgReturn, byte[] content) { + + Osx.Outbound.Builder builder = Osx.Outbound.newBuilder(); + builder.setCode(code); + builder.setMessage(msgReturn); + if(content!=null) { + builder.setPayload(ByteString.copyFrom(content)); + } + return builder; + } + + + + + + + + public static Osx.Outbound buildResponse(String code, String msgReturn, TransferQueue.TransferQueueConsumeResult messageWraper) { + + byte[] content = null; + if 
+    public static Osx.Outbound buildResponse(String code, String msgReturn, TransferQueue.TransferQueueConsumeResult messageWrapper) {
+        byte[] content = null;
+        if (messageWrapper != null) {
+            Osx.Message message = null;
+            try {
+                message = Osx.Message.parseFrom(messageWrapper.getMessage().getBody());
+            } catch (InvalidProtocolBufferException e) {
+                logger.error("parse message error", e);
+            }
+            if (message != null) {
+                content = message.toByteArray();
+            }
+        }
+        Osx.Outbound.Builder builder = buildResponseInner(code, msgReturn, content);
+        if (messageWrapper != null) {
+            builder.putMetadata(Osx.Metadata.MessageOffSet.name(), Long.toString(messageWrapper.getRequestIndex()));
+        }
+        return builder.build();
+    }
+
+    public static void checkResponse(Osx.Outbound outbound) {
+        if (outbound != null) {
+            String code = outbound.getCode();
+            String message = outbound.getMessage();
+            if (!StatusCode.SUCCESS.equals(code)) {
+                logger.error("remote response error : {}", outbound);
+                throw new RemoteRpcException("remote code : " + code + " remote msg: " + message);
+            }
+        } else {
+            throw new RemoteRpcException("has no response");
+        }
+    }
+
+    public static void writeHttpRespose(HttpServletResponse response, String code,
+                                        String msg,
+                                        byte[] content) {
+        try {
+            response.setHeader(PtpHttpHeader.ReturnCode, code);
+            response.setHeader(PtpHttpHeader.MessageCode, msg);
+            OutputStream outputStream = response.getOutputStream();
+            if (content != null) {
+                outputStream.write(content);
+            }
+            outputStream.flush();
+        } catch (IOException e) {
+            logger.error("write http response error", e);
+        }
+    }
+
+
+    public static void main(String[] args) {
+        TransferUtil a = new TransferUtil();
+        a.testHttps();
+    }
+
+    public void testHttps() {
+        try {
+            new Thread(() -> {
+                Osx.Outbound outbound = null;
+                try {
+                    Thread.sleep(3000);
+                    outbound = HttpsClientPool.sendPtpPost("https://127.0.0.1:8088/osx/inbound", new byte[10], null, "D:\\22\\ca.crt", "D:\\22\\174_2.crt", "D:\\22\\174_2.key");
+                } catch (Exception e) {
+                    e.printStackTrace();
+                }
+                System.out.println("outbound = " + outbound);
+
+            }).start();
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/UtilAll.java b/java/osx/osx-broker/src/main/java/com/osx/broker/util/UtilAll.java
similarity index 99%
rename from java/osx/broker/src/main/java/com/osx/broker/util/UtilAll.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/util/UtilAll.java
index ac00f84f50..7273932a37 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/util/UtilAll.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/util/UtilAll.java
@@ -421,9 +421,6 @@ public static boolean isInternalIP(byte[] ip) {
             throw new RuntimeException("illegal ipv4 bytes");
         }
-        //10.0.0.0~10.255.255.255
-        //172.16.0.0~172.31.255.255
-        //192.168.0.0~192.168.255.255
         if (ip[0] == (byte) 10) {
             return true;
diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/AbstractZookeeperClient.java b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/AbstractZookeeperClient.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/zk/AbstractZookeeperClient.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/zk/AbstractZookeeperClient.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/ChildListener.java b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/ChildListener.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/zk/ChildListener.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/zk/ChildListener.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java
b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java
similarity index 98%
rename from java/osx/broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java
index a9368bbe02..538a0ea98b 100644
--- a/java/osx/broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java
@@ -199,14 +199,10 @@ public void createEphemeral(String path, String data) throws NodeExistsException
     public void delete(String path) {
         try {
             if (aclEnable) {
-//                Stat stat = client.checkExists().forPath(path);
-//                client.delete().withVersion(stat.getAversion()).forPath(path);
                 this.clearAcl(path);
             }
-            logger.info("xxxxxxxxxxxxx");
             client.delete().forPath(path);
-            logger.info("pppppppppppppppp");
         } catch (NoNodeException e) {
             e.printStackTrace();
         } catch (Exception e) {
diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/DataListener.java b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/DataListener.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/zk/DataListener.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/zk/DataListener.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/EventType.java b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/EventType.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/zk/EventType.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/zk/EventType.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/StateListener.java b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/StateListener.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/zk/StateListener.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/zk/StateListener.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/ZkConfig.java b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/ZkConfig.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/zk/ZkConfig.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/zk/ZkConfig.java
diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/ZookeeperClient.java b/java/osx/osx-broker/src/main/java/com/osx/broker/zk/ZookeeperClient.java
similarity index 100%
rename from java/osx/broker/src/main/java/com/osx/broker/zk/ZookeeperClient.java
rename to java/osx/osx-broker/src/main/java/com/osx/broker/zk/ZookeeperClient.java
diff --git a/java/osx/broker/src/main/java/com/osx/tech/provider/FateTechProvider.java b/java/osx/osx-broker/src/main/java/com/osx/tech/provider/FateTechProvider.java
similarity index 53%
rename from java/osx/broker/src/main/java/com/osx/tech/provider/FateTechProvider.java
rename to java/osx/osx-broker/src/main/java/com/osx/tech/provider/FateTechProvider.java
index d63286586f..147655c502 100644
--- a/java/osx/broker/src/main/java/com/osx/tech/provider/FateTechProvider.java
+++ b/java/osx/osx-broker/src/main/java/com/osx/tech/provider/FateTechProvider.java
@@ -16,31 +16,36 @@
 package com.osx.tech.provider;
-import com.google.common.base.Preconditions;
 import com.google.common.collect.Sets;
 import com.google.protobuf.ByteString;
-import com.osx.broker.ServiceContainer;
+import com.osx.api.context.Context;
+import com.osx.broker.interceptor.PcpHandleInterceptor;
+import com.osx.broker.interceptor.TokenValidatorInterceptor;
+import
com.osx.broker.router.RouterRegister;
 import com.osx.broker.util.ContextUtil;
-import com.osx.broker.interceptor.RequestHandleInterceptor;
 import com.osx.broker.interceptor.RouterInterceptor;
 import com.osx.broker.ptp.*;
 import com.osx.broker.util.TransferUtil;
 import com.osx.core.config.MetaInfo;
 import com.osx.core.constant.Dict;
+import com.osx.api.constants.Protocol;
 import com.osx.core.constant.PtpHttpHeader;
-import com.osx.core.context.Context;
+
 import com.osx.core.exceptions.ErrorMessageUtil;
 import com.osx.core.exceptions.ExceptionInfo;
 import com.osx.core.exceptions.ParameterException;
-import com.osx.core.frame.Lifecycle;
 import com.osx.core.provider.TechProvider;
 import com.osx.core.ptp.TargetMethod;
 import com.osx.core.service.InboundPackage;
 import com.osx.core.service.OutboundPackage;
 import com.osx.core.service.ServiceAdaptor;
+import com.osx.core.utils.FlowLogUtil;
 import io.grpc.stub.StreamObserver;
 import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.ppc.ptp.Osx;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import javax.servlet.http.HttpServletRequest;
@@ -55,84 +60,72 @@
  * FATE-specific implementation
  */
-public class FateTechProvider implements TechProvider, Lifecycle {
+public class FateTechProvider implements TechProvider {
+    Logger logger = LoggerFactory.getLogger(FateTechProvider.class);
     ConcurrentMap<String, ServiceAdaptor> serviceAdaptorConcurrentMap = new ConcurrentHashMap<>();
+    PcpHandleInterceptor requestHandleInterceptor;
+    TokenValidatorInterceptor tokenValidatorInterceptor;
+    RouterInterceptor routerInterceptor;
+    private Set<String> httpAllowedMethod = Sets.newHashSet(TargetMethod.PRODUCE_MSG.name(), TargetMethod.UNARY_CALL.name());
+
+    public FateTechProvider() {
+        requestHandleInterceptor = new PcpHandleInterceptor();
+        tokenValidatorInterceptor = new TokenValidatorInterceptor();
+        routerInterceptor = new RouterInterceptor();
+        registerServiceAdaptor();
+    }
-    RequestHandleInterceptor requestHandleInterceptor;
-    RouterInterceptor routerInterceptor;
-
-    private Set<String> httpAllowedMethod= Sets.newHashSet(TargetMethod.PRODUCE_MSG.name(),TargetMethod.UNARY_CALL.name());
-
-    private void checkHttpAllowedMethod(String targetMethod){
-
-        if(!httpAllowedMethod.contains(targetMethod)){
-            throw new ParameterException("target method :"+targetMethod+"is not allowed");
        }
-    }
+    private void checkHttpAllowedMethod(String targetMethod) {
+        if (!httpAllowedMethod.contains(targetMethod)) {
+            throw new ParameterException("target method: " + targetMethod + " is not allowed");
        }
    }
     @Override
     public void processHttpInvoke(HttpServletRequest request, HttpServletResponse response) {
-        Context context = ContextUtil.buildContext();
-        Osx.Inbound.Builder inboundBuilder ;
-        ServiceAdaptor serviceAdaptor=null;
+        Context context = ContextUtil.buildFateContext(Protocol.http);
+        context.putData(Dict.HTTP_SERVLET_RESPONSE, response);
+        Osx.Inbound.Builder inboundBuilder;
+        ServiceAdaptor serviceAdaptor = null;
         try {
-            String Version = request.getHeader(PtpHttpHeader.Version);
-            String TechProviderCode = request.getHeader(PtpHttpHeader.TechProviderCode);
-            String TraceID = request.getHeader(PtpHttpHeader.TraceID);
-            String Token = request.getHeader(PtpHttpHeader.Token);
-            String SourceNodeID = request.getHeader(PtpHttpHeader.SourceNodeID);
-            String TargetNodeID = request.getHeader(PtpHttpHeader.TargetNodeID);
-            String SourceInstID = request.getHeader(PtpHttpHeader.SourceInstID);
-            String TargetInstID = request.getHeader(PtpHttpHeader.TargetInstID);
-            String SessionID =
request.getHeader(PtpHttpHeader.SessionID); - String MessageTopic = request.getHeader(PtpHttpHeader.MessageTopic); - String MessageCode = request.getHeader(PtpHttpHeader.MessageCode); - String SourceComponentName = request.getHeader(PtpHttpHeader.SourceComponentName); - String TargetComponentName = request.getHeader(PtpHttpHeader.TargetComponentName); - String TargetMethod = request.getHeader(PtpHttpHeader.TargetMethod); - String MessageOffSet = request.getHeader(PtpHttpHeader.MessageOffSet); - String InstanceId = request.getHeader(PtpHttpHeader.InstanceId); - String Timestamp = request.getHeader(PtpHttpHeader.Timestamp); - context.setSrcPartyId(SourceNodeID); - context.setDesPartyId(TargetNodeID); - context.setSessionId(SessionID); - context.setTopic(MessageTopic); - context.setActionType(TargetMethod); - inboundBuilder = TransferUtil.buildPbFromHttpRequest(request); + inboundBuilder = TransferUtil.buildPbFromHttpRequest(context, request); String targetMethod = inboundBuilder.getMetadataMap().get(Osx.Metadata.TargetMethod.name()); - checkHttpAllowedMethod(TargetMethod); - serviceAdaptor = this.getServiceAdaptor(TargetMethod); + if (StringUtils.isEmpty(targetMethod)) { + throw new ParameterException("target method is null"); + } + checkHttpAllowedMethod(targetMethod); + serviceAdaptor = this.getServiceAdaptor(targetMethod); + byte[] buffer = new byte[MetaInfo.PROPERTY_HTTP_REQUEST_BODY_MAX_SIZE]; int length = IOUtils.read(request.getInputStream(), buffer); byte[] data = new byte[length]; System.arraycopy(buffer, 0, data, 0, length); inboundBuilder.setPayload(ByteString.copyFrom(data)); - }catch(Exception e){ - ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context,e); - this.writeHttpRespose(response, exceptionInfo.getCode(),exceptionInfo.getMessage(),null); + } catch (Exception e) { + ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); + this.writeHttpRespose(response, exceptionInfo.getCode(), exceptionInfo.getMessage(), null); context.setReturnCode(exceptionInfo.getCode()); context.setReturnMsg(exceptionInfo.getMessage()); - context.printFlowLog(); - return ; + FlowLogUtil.printFlowLog(context); + return; } - InboundPackage inboundPackage = new InboundPackage(); - inboundPackage.setBody(inboundBuilder.build()); - OutboundPackage outboundPackage = serviceAdaptor.service(context, inboundPackage); - Osx.Outbound outbound = outboundPackage.getData(); - response.setContentType(Dict.CONTENT_TYPE_JSON_UTF8); - this.writeHttpRespose(response,outbound.getCode(),outbound.getMessage(),outbound.getPayload().toByteArray() ); + InboundPackage inboundPackage = new InboundPackage(); + inboundPackage.setBody(inboundBuilder.build()); + OutboundPackage outboundPackage = serviceAdaptor.service(context, inboundPackage); + Osx.Outbound outbound = outboundPackage.getData(); + response.setContentType(Dict.CONTENT_TYPE_JSON_UTF8); + TransferUtil.writeHttpRespose(response, outbound.getCode(), outbound.getMessage(), outbound.getPayload().toByteArray()); } - private void writeHttpRespose(HttpServletResponse response,String code, - String msg, - byte[] content){ + private void writeHttpRespose(HttpServletResponse response, String code, + String msg, + byte[] content) { try { - response.setHeader(PtpHttpHeader.ReturnCode,code); - response.setHeader(PtpHttpHeader.MessageCode,msg); - OutputStream outputStream = response.getOutputStream(); - if(content!=null) { + response.setHeader(PtpHttpHeader.ReturnCode, code); + 
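// PTP status is surfaced through response headers (ReturnCode / MessageCode); the payload, if any, is written raw to the body. +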
response.setHeader(PtpHttpHeader.MessageCode, msg);
+            OutputStream outputStream = response.getOutputStream();
+            if (content != null) {
                 outputStream.write(content);
             }
             outputStream.flush();
@@ -144,8 +137,8 @@ private void writeHttpRespose(HttpServletResponse response,String code,
     @Override
     public void processGrpcInvoke(Osx.Inbound request, StreamObserver<Osx.Outbound> responseObserver) {
-        Context context = ContextUtil.buildContext();
-        context.putData(Dict.RESPONSE_STREAM_OBSERVER,responseObserver);
+        Context context = ContextUtil.buildFateContext(Protocol.grpc);
+        context.putData(Dict.RESPONSE_STREAM_OBSERVER, responseObserver);
         Osx.Outbound result = null;
         try {
             Map<String, String> metaDataMap = request.getMetadataMap();
@@ -160,79 +153,85 @@ public void processGrpcInvoke(Osx.Inbound request, StreamObserver
                 if (outboundPackage.getData() != null) {
                     result = outboundPackage.getData();
                 }
-        }catch (Exception e){
-            ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context,e);
+        } catch (Exception e) {
+            ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e);
             //this.writeHttpRespose(response, exceptionInfo.getCode(),exceptionInfo.getMessage(),null);
             context.setReturnCode(exceptionInfo.getCode());
             context.setReturnMsg(exceptionInfo.getMessage());
-            context.printFlowLog();
+            FlowLogUtil.printFlowLog(context);
             result = Osx.Outbound.newBuilder().setCode(exceptionInfo.getCode()).setMessage(exceptionInfo.getMessage()).build();
         }
-        if(result!=null) {
+        if (result != null) {
             responseObserver.onNext(result);
             responseObserver.onCompleted();
         }
     }
-    @Override
-    public String getProviderId() {
-        return MetaInfo.PROPERTY_FATE_TECH_PROVIDER;
-    }
-
     @Override
     public StreamObserver processGrpcTransport(Osx.Inbound firstPackage, StreamObserver responseObserver) {
         Map<String, String> metaDataMap = firstPackage.getMetadataMap();
         String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name());
         ServiceAdaptor serviceAdaptor = this.getServiceAdaptor(targetMethod);
-        if(serviceAdaptor==null){
-            throw new ParameterException("invalid target method "+targetMethod);
+        if (serviceAdaptor == null) {
+            throw new ParameterException("invalid target method " + targetMethod);
         }
-        Context context = ContextUtil.buildContext();
+        Context context = ContextUtil.buildFateContext(Protocol.grpc);
         InboundPackage inboundPackage = new InboundPackage();
         inboundPackage.setBody(responseObserver);
-        OutboundPackage<StreamObserver> outboundPackage = serviceAdaptor.service( context, inboundPackage);
-        if(outboundPackage!=null&&outboundPackage.getData()!=null){
-            return (StreamObserver)outboundPackage.getData();
-        }else{
+        OutboundPackage<StreamObserver> outboundPackage = serviceAdaptor.service(context, inboundPackage);
+        if (outboundPackage != null && outboundPackage.getData() != null) {
+            return (StreamObserver) outboundPackage.getData();
+        } else {
             return null;
         }
+    }
+    @Override
+    public void processGrpcPeek(Osx.PeekInbound inbound, StreamObserver responseObserver) {
     }
     @Override
-    public void init() {
-        Preconditions.checkArgument(ServiceContainer.fateRouterService != null);
-        requestHandleInterceptor = new RequestHandleInterceptor();
-        routerInterceptor =ServiceContainer.routerInterceptor;
-        registerServiceAdaptor();
+    public void processGrpcPush(Osx.PushInbound inbound, StreamObserver responseObserver) {
+
     }
     @Override
-    public void start() {
+    public void processGrpcPop(Osx.PopInbound inbound, StreamObserver responseObserver) {
     }
     @Override
-    public void destroy() {
+    public void processGrpcRelease(Osx.ReleaseInbound inbound, StreamObserver
responseObserver) { } + + public ServiceAdaptor getServiceAdaptor(String name) { return this.serviceAdaptorConcurrentMap.get(name); } + private void registerServiceAdaptor() { - this.serviceAdaptorConcurrentMap.put(TargetMethod.UNARY_CALL.name(), new PtpUnaryCallService().addPreProcessor(requestHandleInterceptor) + this.serviceAdaptorConcurrentMap.put(TargetMethod.UNARY_CALL.name(), new PtpUnaryCallService() + .addPreProcessor(requestHandleInterceptor) + .addPreProcessor(tokenValidatorInterceptor) .addPreProcessor(routerInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.PRODUCE_MSG.name(), new PtpProduceService().addPreProcessor(requestHandleInterceptor) + this.serviceAdaptorConcurrentMap.put(TargetMethod.PRODUCE_MSG.name(), new PtpProduceService() + .addPreProcessor(requestHandleInterceptor) .addPreProcessor(routerInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.ACK_MSG.name(), new PtpAckService().addPreProcessor(requestHandleInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.CONSUME_MSG.name(), new PtpConsumeService().addPreProcessor(requestHandleInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.QUERY_TOPIC.name(), new PtpQueryTransferQueueService().addPreProcessor(requestHandleInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.CANCEL_TOPIC.name(), new PtpCancelTransferService().addPreProcessor(requestHandleInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.ACK_MSG.name(), new PtpAckService() + .addPreProcessor(requestHandleInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.CONSUME_MSG.name(), new PtpConsumeService() + .addPreProcessor(requestHandleInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.QUERY_TOPIC.name(), new PtpQueryTransferQueueService() + .addPreProcessor(requestHandleInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.CANCEL_TOPIC.name(), new PtpCancelTransferService() + .addPreProcessor(requestHandleInterceptor)); this.serviceAdaptorConcurrentMap.put(TargetMethod.PUSH.name(), new PtpPushService()); this.serviceAdaptorConcurrentMap.put(TargetMethod.APPLY_TOKEN.name(), new PtpClusterTokenApplyService()); - this.serviceAdaptorConcurrentMap.put(TargetMethod.APPLY_TOPIC.name(),new PtpClusterTopicApplyService()); + this.serviceAdaptorConcurrentMap.put(TargetMethod.APPLY_TOPIC.name(), new PtpClusterTopicApplyService()); + // this.serviceAdaptorConcurrentMap.put(TargetMethod.TEST_STREAM.name(), new PtpStreamTestService()); } } diff --git a/java/osx/broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java b/java/osx/osx-broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java similarity index 51% rename from java/osx/broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java rename to java/osx/osx-broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java index bd124e277d..c05a21b5c7 100644 --- a/java/osx/broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java +++ b/java/osx/osx-broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java @@ -15,11 +15,18 @@ */ package com.osx.tech.provider; -import com.google.common.base.Preconditions; +import com.osx.core.config.MetaInfo; +import com.osx.core.constant.Dict; +import com.osx.core.exceptions.ParameterException; import com.osx.core.frame.Lifecycle; import com.osx.core.provider.TechProvider; +import com.osx.core.utils.ClassUtils; +import com.osx.core.utils.PropertiesUtil; +import org.apache.commons.lang3.StringUtils; +import 
org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import java.util.Map;
+import java.util.Properties;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -28,22 +35,35 @@
  */
 public class TechProviderRegister implements Lifecycle {
+    Logger logger = LoggerFactory.getLogger(TechProviderRegister.class);
     ConcurrentMap<String, TechProvider> registerMap = new ConcurrentHashMap<>();
-
+    final String configFileName = "components/provider.properties";
+
     public TechProvider select(String techProviderCode) {
-        Preconditions.checkArgument(techProviderCode != null);
+        if (StringUtils.isEmpty(techProviderCode)) {
+            throw new ParameterException("techProviderCode is null");
+        }
         return this.registerMap.get(techProviderCode);
     }
     public void init() {
-        FateTechProvider fateTechProvider = new FateTechProvider();
-        fateTechProvider.init();
-        this.registerMap.put(fateTechProvider.getProviderId(), fateTechProvider);
+        Properties properties = PropertiesUtil.getProperties(MetaInfo.PROPERTY_CONFIG_DIR + Dict.SLASH + configFileName);
+        properties.forEach((k, v) -> {
+            try {
+                this.registerMap.put(k.toString(), (TechProvider) ClassUtils.newInstance(v.toString()));
+            } catch (Exception e) {
+                logger.error("provider {} class {} init error", k, v, e);
+            }
+        });
+        logger.info("tech provider register : {}", this.registerMap);
     }
+    @Override
     public void start() {
+        init();
     }
     @Override
     public void destroy() {
+        this.registerMap.clear();
     }
 }
diff --git a/java/osx/osx-broker/src/main/resources/broker/broker.properties b/java/osx/osx-broker/src/main/resources/broker/broker.properties
new file mode 100644
index 0000000000..22228c5240
--- /dev/null
+++ b/java/osx/osx-broker/src/main/resources/broker/broker.properties
@@ -0,0 +1,56 @@
+grpc.port=9370
+# Http switch for the server.
+# If set to True, the server will open the http port.
+# http port configuration can be set through http.port
+open.http.server=false
+# port of http
+http.port=8087
+https.port=8088
+# whether the http server uses TLS
+#http.use.tls=false
+# whether the grpc server uses TLS.
+# If true, a grpc port will be specially opened to listen for TLS requests
+# grpc tls port configuration can be set through grpc.tls.port
+open.grpc.tls.server=false
+grpc.tls.port=9883
+# the partyId of self, multiple partyIds can be set.
+# eg: 9999,10000,10001 +self.party=9999 +# deployment mode, including cluster/standalone, +# respectively representing cluster mode and standalone mode , +# and standalone is used by default +deploy.mode=cluster +# the zookeeper address needs to be configured when the deployment mode is cluster +zk.url=127.0.0.1:2181 +# the IP of the cluster manager component of eggroll +eggroll.cluster.manager.ip = localhost +# the port of the cluster manager component of eggroll +eggroll.cluster.manager.port = 4670 +# maximum number of message retries +produce.msg.max.try.time =3 + +http.client.method.config = {"UNARY_CALL":{"reqTimeout":0,"connectionTimeout":0,"socketTimeout":0}} + +http.use.tls=false + +http.ssl.trust.store.type=PKCS12 + +http.ssl.key.store.alias=22 + +http.ssl.key.store.password=123456 + +#http.ssl.trust.store.path=D:\\44\\127.0.0.1.pfx + +server.ca.file= +server.cert.chain.file= +server.private.key.file= + + + + + + + + + + diff --git a/java/osx/broker/src/main/resources/flowRule.json b/java/osx/osx-broker/src/main/resources/broker/flowRule.json similarity index 100% rename from java/osx/broker/src/main/resources/flowRule.json rename to java/osx/osx-broker/src/main/resources/broker/flowRule.json diff --git a/java/osx/osx-broker/src/main/resources/broker/route_table.json b/java/osx/osx-broker/src/main/resources/broker/route_table.json new file mode 100644 index 0000000000..8c36e1acae --- /dev/null +++ b/java/osx/osx-broker/src/main/resources/broker/route_table.json @@ -0,0 +1,44 @@ +{ + "route_table": + { + "20008": + { + + "fateflow":[ + { + "port": 9360, + "ip": "127.0.0.1" + } + ] + }, + "10008":{ + "default":[{ + "protocol":"grpc", + "url": "https://127.0.0.1:8088/osx/inbound", + "certChainFile": "D:/22/174_x.crt", + "privateKeyFile": "D:/22/174_x.key", + "caFile": "D:/22/ca_x.crt", + "negotiationType": "TLS", + "useSSL": true, + "port": 9883, + "ip": "127.0.0.1" + }] + }, + "10001":{ + "default":[{ + "protocol":"http", + "url": "http://localhost:8222/osx/inbound" + }] + }, + "9999":{ + "default":[{ + "port": 9360, + "ip": "127.0.0.1" + }] + } + }, + "permission": + { + "default_allow": true + } +} diff --git a/java/osx/osx-broker/src/main/resources/components/provider.properties b/java/osx/osx-broker/src/main/resources/components/provider.properties new file mode 100644 index 0000000000..4c73df057f --- /dev/null +++ b/java/osx/osx-broker/src/main/resources/components/provider.properties @@ -0,0 +1,2 @@ +FATE=com.osx.tech.provider.FateTechProvider + diff --git a/java/osx/osx-broker/src/main/resources/components/router.properties b/java/osx/osx-broker/src/main/resources/components/router.properties new file mode 100644 index 0000000000..f9a11b588f --- /dev/null +++ b/java/osx/osx-broker/src/main/resources/components/router.properties @@ -0,0 +1 @@ +FATE=com.osx.broker.router.DefaultFateRouterServiceImpl \ No newline at end of file diff --git a/java/osx/osx-broker/src/main/resources/components/translator.properties b/java/osx/osx-broker/src/main/resources/components/translator.properties new file mode 100644 index 0000000000..153218757b --- /dev/null +++ b/java/osx/osx-broker/src/main/resources/components/translator.properties @@ -0,0 +1,2 @@ +9999-10000=com.osx.broker.demo.DemoTranslator +10000-9999=com.osx.broker.demo.DemoTranslator \ No newline at end of file diff --git a/java/osx/broker/src/main/resources/log4j2.xml b/java/osx/osx-broker/src/main/resources/log4j2.xml similarity index 100% rename from java/osx/broker/src/main/resources/log4j2.xml rename to 
java/osx/osx-broker/src/main/resources/log4j2.xml diff --git a/java/osx/broker/src/test/java/com/osx/broker/cluster/ClusterClientEndpointTest.java b/java/osx/osx-broker/src/test/java/com/osx/broker/cluster/ClusterClientEndpointTest.java similarity index 100% rename from java/osx/broker/src/test/java/com/osx/broker/cluster/ClusterClientEndpointTest.java rename to java/osx/osx-broker/src/test/java/com/osx/broker/cluster/ClusterClientEndpointTest.java diff --git a/java/osx/osx-broker/src/test/java/com/osx/broker/mock/MockHttpServer.java b/java/osx/osx-broker/src/test/java/com/osx/broker/mock/MockHttpServer.java new file mode 100644 index 0000000000..ba71921908 --- /dev/null +++ b/java/osx/osx-broker/src/test/java/com/osx/broker/mock/MockHttpServer.java @@ -0,0 +1,27 @@ +package com.osx.broker.mock; + +import com.osx.broker.ServiceContainer; +import com.osx.broker.server.OsxServer; +import com.osx.core.config.MetaInfo; + +import java.util.HashSet; + + +public class MockHttpServer { + + + + + + public static void main(String[] args){ + HashSet selfPartyIds = new HashSet(); + selfPartyIds.add("10001"); + MetaInfo.PROPERTY_SELF_PARTY= selfPartyIds; + MetaInfo.PROPERTY_GRPC_PORT=9372; + MetaInfo.PROPERTY_HTTP_PORT=8222; + MetaInfo.PROPERTY_OPEN_HTTP_SERVER = Boolean.TRUE; + ServiceContainer.init(); + } + + +} diff --git a/java/osx/broker/src/test/java/com/osx/broker/mock/MockServer.java b/java/osx/osx-broker/src/test/java/com/osx/broker/mock/MockServer.java similarity index 73% rename from java/osx/broker/src/test/java/com/osx/broker/mock/MockServer.java rename to java/osx/osx-broker/src/test/java/com/osx/broker/mock/MockServer.java index 424aa14ed6..46577f0ec1 100644 --- a/java/osx/broker/src/test/java/com/osx/broker/mock/MockServer.java +++ b/java/osx/osx-broker/src/test/java/com/osx/broker/mock/MockServer.java @@ -25,7 +25,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -import static com.osx.broker.ServiceContainer.proxyGrpcService; public class MockServer { @@ -66,34 +65,34 @@ public static void main(String[] args) { } - private static Server buildServer() { - NettyServerBuilder nettyServerBuilder = (NettyServerBuilder) ServerBuilder.forPort(9375); - nettyServerBuilder.addService(ServerInterceptors.intercept(proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - // nettyServerBuilder.addService(ServerInterceptors.intercept(queueGrpcservice, new ServiceExceptionHandler(),new ContextPrepareInterceptor())); - // nettyServerBuilder.addService(ServerInterceptors.intercept(commonService, new ServiceExceptionHandler(),new ContextPrepareInterceptor())); - - nettyServerBuilder - .executor(Executors.newCachedThreadPool()) - .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) - .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) - .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) - .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) - nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) - nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) - 
nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) - nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) - nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) - nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) - nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); - return nettyServerBuilder.build(); - } +// private static Server buildServer() { +// NettyServerBuilder nettyServerBuilder = (NettyServerBuilder) ServerBuilder.forPort(9375); +// nettyServerBuilder.addService(ServerInterceptors.intercept(proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); +// // nettyServerBuilder.addService(ServerInterceptors.intercept(queueGrpcservice, new ServiceExceptionHandler(),new ContextPrepareInterceptor())); +// // nettyServerBuilder.addService(ServerInterceptors.intercept(commonService, new ServiceExceptionHandler(),new ContextPrepareInterceptor())); +// +// nettyServerBuilder +// .executor(Executors.newCachedThreadPool()) +// .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) +// .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) +// .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) +// .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); +// if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) +// nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) +// nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) +// nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) +// nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); +// if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) +// nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) +// nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) +// nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); +// return nettyServerBuilder.build(); +// } private static class PtpService extends PrivateTransferProtocolGrpc.PrivateTransferProtocolImplBase { diff --git a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/EggrollTest.java b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/EggrollTest.java similarity index 100% rename from 
java/osx/broker/src/test/java/com/osx/broker/test/grpc/EggrollTest.java rename to java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/EggrollTest.java diff --git a/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/Grpc_UC.java b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/Grpc_UC.java new file mode 100644 index 0000000000..ee40184aef --- /dev/null +++ b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/Grpc_UC.java @@ -0,0 +1,55 @@ +package com.osx.broker.test.grpc; + +import com.osx.api.router.RouterInfo; +import com.osx.core.constant.Dict; +import com.osx.core.constant.StatusCode; +import com.osx.core.context.FateContext; +import com.osx.core.exceptions.RemoteRpcException; +import com.osx.core.frame.GrpcConnectionFactory; +import com.osx.core.utils.JsonUtil; +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import org.junit.Test; +import org.ppc.ptp.Osx; +import org.ppc.ptp.PrivateTransferProtocolGrpc; + +public class Grpc_UC { + + String contextStr = "{\"actionType\":\"unary-call-new\",\"protocol\":\"grpc\",\"techProviderCode\":\"FATE\",\"needCheckRouterInfo\":true,\"costTime\":0,\"resourceName\":\"I_unary-call-new\",\"timeStamp\":1685499290484,\"downstreamCost\":0,\"downstreamBegin\":0,\"destination\":false,\"sourceIp\":\"127.0.0.1\",\"desPartyId\":\"20008\",\"srcPartyId\":\"\",\"returnCode\":\"0\",\"desComponent\":\"fateflow\",\"routerInfo\":{\"protocol\":\"grpc\",\"sourcePartyId\":\"\",\"desPartyId\":\"20008\",\"desRole\":\"fateflow\",\"url\":\"\",\"host\":\"127.0.0.1\",\"port\":9360,\"useSSL\":false,\"negotiationType\":\"\",\"certChainFile\":\"\",\"privateKeyFile\":\"\",\"caFile\":\"\",\"resource\":\"-20008\",\"cycle\":false},\"selfPartyId\":\"10008\"}"; + String routerJson = "{\n" + + " \"protocol\": \"grpc\",\n" + + " \"sourcePartyId\": \"\",\n" + + " \"desPartyId\": \"10008\",\n" + + " \"desRole\": \"fateflow\",\n" + + " \"url\": \"http://127.0.0.1:8087/osx/inbound\",\n" + + " \"host\": \"127.0.0.1\",\n" + + " \"port\": 9883,\n" + + " \"useSSL\": true,\n" + + " \"negotiationType\": \"TLS\",\n" + + " \"certChainFile\": \"D:/33/127.0.0.1.crt\",\n" + + " \"privateKeyFile\": \"D:/33/127.0.0.1.key\",\n" + + " \"caFile\": \"D:/33/testRoot.crt\",\n" + + " \"resource\": \"-10008\",\n" + + " \"cycle\": false\n" + + "}"; + + @Test + public void run(){ + FateContext context = JsonUtil.json2Object(contextStr,FateContext.class); + RouterInfo routerInfo = JsonUtil.json2Object(routerJson,RouterInfo.class); + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = null; + if (context.getData(Dict.BLOCKING_STUB) == null) { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, true); + stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); + } else { + stub = (PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub) context.getData(Dict.BLOCKING_STUB); + } + try { + // logger.info("===========send data {}",produceRequest); + Osx.Outbound invoke = stub.invoke(null); + } catch (StatusRuntimeException e) { + e.printStackTrace(); + throw new RemoteRpcException(StatusCode.NET_ERROR, "send to " + routerInfo.toKey() + " error : " + e.getMessage()); + } + } +} diff --git a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java similarity index 52% rename from java/osx/broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java rename to 
java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java
index 9209eed62e..6de3f3b16e 100644
--- a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java
+++ b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java
@@ -1,5 +1,6 @@
 package com.osx.broker.test.grpc;
 
+import com.google.protobuf.ByteString;
 import com.osx.core.config.MetaInfo;
 import com.osx.core.constant.Dict;
 import com.osx.core.ptp.TargetMethod;
@@ -7,12 +8,15 @@
 import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
 import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
 import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
+import io.grpc.stub.StreamObserver;
 import org.junit.Before;
 import org.junit.Test;
 import org.ppc.ptp.Osx;
 import org.ppc.ptp.PrivateTransferProtocolGrpc;
 
 import java.io.File;
+import java.nio.charset.StandardCharsets;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
 public class NewFateTest {
@@ -27,13 +31,16 @@ public class NewFateTest {
     String transferId = "testTransferId";
     String sessionId = "testSessionId";
     PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub;
-
+    PrivateTransferProtocolGrpc.PrivateTransferProtocolStub stub;
     @Before
     public void init() {
         ManagedChannel managedChannel = createManagedChannel(ip, port);
        // stub = PrivateTransferProtocolGrpc.newBlockingStub();
       // ManagedChannel managedChannel2 = createManagedChannel(ip, port);
         blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel);
+        stub = PrivateTransferProtocolGrpc.newStub(managedChannel);
+
     }
 
     public static ManagedChannel createManagedChannel(String ip, int port) {
@@ -62,22 +69,80 @@ public static ManagedChannel createManagedChannel(String ip, int port) {
     }
 
-    @Test
-    public void testUnaryCall(){
+    public void testUnaryCall(byte[] data){
         Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder();
         inboundBuilder.putMetadata(Osx.Header.Version.name(), "123");
-        inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), MetaInfo.PROPERTY_FATE_TECH_PROVIDER);
+        inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), "FATE");
+//        inboundBuilder.putMetadata(Osx.Header.Token.name(), "testToken");
         inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), "9999");
         inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), "10000");
-        inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), "");
-        inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), "");
+        inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), "9999");
+        inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), "10000");
         inboundBuilder.putMetadata(Osx.Header.SessionID.name(), "testSessionID");
-        inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), TargetMethod.UNARY_CALL.name());
+        inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), "UNARY_CALL");
         inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), "fateflow");
         inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), "");
-        // inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), transferId);
+        inboundBuilder.putMetadata(Osx.Header.TraceID.name(), "28938999993");
+        inboundBuilder.setPayload(ByteString.copyFrom(data));
         Osx.Outbound outbound = blockingStub.invoke(inboundBuilder.build());
         System.err.println("response : "+outbound);
     }
+
+    @Test
+    public void testStream(){
+        System.err.println("==========================");
+        io.grpc.stub.StreamObserver<Osx.Inbound> reqSb = stub.transport(new StreamObserver<Osx.Outbound>()
{ + @Override + public void onNext(Osx.Outbound outbound) { + System.err.println(outbound); + } + @Override + public void onError(Throwable throwable) { + throwable.printStackTrace(); + } + @Override + public void onCompleted() { + System.err.println("completed"); + } + }); + for(int i=0;i<3;i++){ + Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); + inboundBuilder.putMetadata(Osx.Header.Version.name(), "123"); + inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), MetaInfo.PROPERTY_FATE_TECH_PROVIDER); + inboundBuilder.putMetadata(Osx.Header.Token.name(), "testToken"); + inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), "9999"); + inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), "10000"); + inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), ""); + inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), ""); + inboundBuilder.putMetadata(Osx.Header.SessionID.name(), "testSessionID"); + inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), TargetMethod.TEST_STREAM.name()); + inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), ""); + inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); + // inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), transferId); + + inboundBuilder.setPayload(ByteString.copyFrom(("test "+i).getBytes(StandardCharsets.UTF_8))); + reqSb.onNext(inboundBuilder.build()); + } + + System.err.println("=========================="); + + } + + + public static void main(String[] args) { + System.err.println("==============="); + NewFateTest newFateTest = new NewFateTest(); + newFateTest.init(); + newFateTest.testStream(); + + CountDownLatch countDownLatch = new CountDownLatch(1); + try { + countDownLatch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + } + } diff --git a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java similarity index 89% rename from java/osx/broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java rename to java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java index 61a86498fd..bea2ed8446 100644 --- a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java +++ b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java @@ -18,9 +18,9 @@ public class OldFateTest { - static int port = 9370;//9371 - // static String ip = "localhost"; - static String ip = "10.42.0.85"; + static int port = 9371;//9371 + static String ip = "localhost"; + static Logger logger = LoggerFactory.getLogger(OldFateTest.class); @@ -148,12 +148,25 @@ public void onCompleted() { // } // for (int t = 0; t < 1; t++) { + String srcPartyId = "10000"; + String desPartyId = "9999"; // new Thread(() -> { StreamObserver requestOb = stub.push(responseOb); for (int i = 0; i < 3; i++) { + +// Proxy.Metadata metadata = packet.getHeader(); +// ByteString encodedRollSiteHeader = metadata.getExt(); + Transfer.RollSiteHeader.Builder rollSiteHeader = Transfer.RollSiteHeader.newBuilder(); + rollSiteHeader.setDstRole("desRole"); + rollSiteHeader.setDstPartyId(desPartyId); + rollSiteHeader.setSrcPartyId(srcPartyId); + rollSiteHeader.setSrcRole("srcRole"); Proxy.Packet.Builder packetBuilder = Proxy.Packet.newBuilder(); - packetBuilder.setHeader(Proxy.Metadata.newBuilder().setSrc(Proxy.Topic.newBuilder().setPartyId("9999")).setDst(Proxy.Topic.newBuilder().setPartyId("10000").setName("kaidengTestTopic").build()).build()); + 
packetBuilder.setHeader(Proxy.Metadata.newBuilder().setSrc(Proxy.Topic.newBuilder().setPartyId("10000")) + .setDst(Proxy.Topic.newBuilder().setPartyId("9999").setName("kaidengTestTopic").build()) + .setExt(rollSiteHeader.build().toByteString()) + .build()); // Transfer.RollSiteHeader.Builder headerBuilder = Transfer.RollSiteHeader.newBuilder(); // headerBuilder.setDstPartyId("10000"); // packetBuilder.setHeader(Proxy.Metadata.newBuilder().setExt(headerBuilder.build().toByteString())); @@ -180,6 +193,7 @@ public void onCompleted() { public static void main(String[] args) { System.err.println("==============="); + //testPush(); testUnaryCall(); CountDownLatch countDownLatch = new CountDownLatch(1); try { diff --git a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java similarity index 72% rename from java/osx/broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java rename to java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java index cb87acd659..0a5fe211df 100644 --- a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java +++ b/java/osx/osx-broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java @@ -4,11 +4,14 @@ import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; +import com.osx.api.router.RouterInfo; +import com.osx.broker.util.TransferUtil; import com.osx.core.config.MetaInfo; import com.osx.core.constant.Dict; +import com.osx.core.context.FateContext; import com.osx.core.frame.GrpcConnectionFactory; import com.osx.core.ptp.TargetMethod; -import com.osx.core.router.RouterInfo; + import com.osx.core.utils.JsonUtil; import io.grpc.ManagedChannel; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; @@ -30,12 +33,18 @@ public class QueueTest { String ip = "localhost"; //int port = 8250;//nginx int port = 9370;//nginx - String desPartyId = "10000"; + String desPartyId = "9999"; String desRole = ""; - String srcPartyId = "9999"; + String srcPartyId = "10000"; String srcRole = ""; String transferId = "testTransferId"; String sessionId = "testSessionId"; + + //4359615 + + + + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub; // FireworkQueueServiceGrpc.FireworkQueueServiceBlockingStub blockingStub; @@ -64,10 +73,10 @@ public static ManagedChannel createManagedChannel(String ip, int port) { @Before public void init() { - ManagedChannel managedChannel = createManagedChannel(ip, port); - // stub = PrivateTransferProtocolGrpc.newBlockingStub(); - ManagedChannel managedChannel2 = createManagedChannel(ip, port); - blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel2); +// ManagedChannel managedChannel = createManagedChannel(ip, port); +// // stub = PrivateTransferProtocolGrpc.newBlockingStub(); +// ManagedChannel managedChannel2 = createManagedChannel(ip, port); +// blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel2); } @@ -88,9 +97,12 @@ public void test02Query() { inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), ""); inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), transferId); - - - Osx.Outbound outbound = blockingStub.invoke(inboundBuilder.build()); + FateContext fateContext= new FateContext(); + RouterInfo routerInfo= new RouterInfo(); + routerInfo.setHost("localhost"); + routerInfo.setPort(9370); + Osx.Outbound outbound 
= TransferUtil.redirect(fateContext, inboundBuilder.build(), routerInfo, false);
+        // Osx.Outbound outbound = blockingStub.invoke(inboundBuilder.build());
         Osx.TopicInfo topicInfo = null;
         try {
             topicInfo = Osx.TopicInfo.parseFrom(outbound.getPayload());
@@ -107,36 +119,71 @@
     }
 
+    public void testUnaryConsume(){
+
+    }
+
+    private byte[] createBigArray(int size){
+        byte[] result = new byte[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = 1;
+        }
+        return result;
+    }
+
+    public String buildMsgBody() {
+        Map<String, Object> body = new HashMap<>();
+        body.put("uri", "/v2/partner/job/resource/apply");
+        body.put("json_body", "{role=host, party_id=10008, job_id=202305251708508595320}");
+        body.put("headers", "{}");
+        body.put("method", "POST");
+        body.put("MessageCode", "111");
+        body.put("RetryCount", "111");
+        return JsonUtil.object2Json(body);
+    }
+
+    public Map<String, String> buildHead() {
+        Map<String, String> head = new HashMap<>();
+//        CONSUME_MSG -> com.osx.broker.ptp.PtpConsumeService
+//        APPLY_TOPIC -> com.osx.broker.ptp.PtpClusterTopicApplyService
+//        APPLY_TOKEN -> com.osx.broker.ptp.PtpClusterTokenApplyService
+//        QUERY_TOPIC -> com.osx.broker.ptp.PtpQueryTransferQueueService
+//        PRODUCE_MSG -> com.osx.broker.ptp.PtpProduceService
+//        ACK_MSG -> com.osx.broker.ptp.PtpAckService
+//        UNARY_CALL -> com.osx.broker.ptp.PtpUnaryCallService
+//        CANCEL_TOPIC -> com.osx.broker.ptp.PtpCancelTransferService
+//        PUSH -> com.osx.broker.ptp.PtpPushService
+        head.put("x-ptp-target-method", "PRODUCE_MSG");
+        head.put("x-ptp-job-id", "202305251708508595320");
+        head.put("x-ptp-tech-provider-code", "FATE");
+        head.put("x-ptp-message-offset", "");
+        head.put("x-ptp-source-inst-id", "");
+        head.put("x-ptp-timestamp", "");
+        head.put("x-ptp-target-component-name", "fateflow");
+        head.put("x-ptp-message-topic", "");
+        head.put("x-ptp-trace-id", "");
+        head.put("x-ptp-source-node-id", "");
+        head.put("x-ptp-source-method", "");
+        head.put("x-ptp-token", "");
+        head.put("x-ptp-message-flag", "");
+        head.put("x-ptp-version", "");
+        head.put("x-ptp-source-component-name", "");
+        head.put("x-ptp-session-id", "");
+        head.put("x-ptp-instance-id", "");
+        head.put("x-ptp-target-node-id", "10008");
+        head.put("x-ptp-target-inst-id", "");
+
+        head.put(PtpHttpHeader.SessionID, "111");
+        head.put(PtpHttpHeader.MessageTopic, "111");
+        head.put(Osx.Metadata.MessageCode.name(), "111");
+        head.put(Osx.Metadata.RetryCount.name(), "111");
+        return head;
+    }
+}
diff --git a/java/osx/osx-broker/src/test/java/com/osx/broker/test/utils/JsonToMapCode.java b/java/osx/osx-broker/src/test/java/com/osx/broker/test/utils/JsonToMapCode.java
new file mode 100644
index 0000000000..600bb1a8bc
--- /dev/null
+++ b/java/osx/osx-broker/src/test/java/com/osx/broker/test/utils/JsonToMapCode.java
@@ -0,0 +1,28 @@
+package com.osx.broker.test.utils;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.osx.core.utils.JsonUtil;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * @date 2023/5/29
+ * @remark
+ */
+public class JsonToMapCode {
+
+    String json = "{\"uri\": \"/v2/partner/job/resource/apply\", \"json_body\": {\"role\": \"host\", \"party_id\": \"10008\", \"job_id\": \"202305251708508595320\"}, \"headers\": {}, \"method\": \"POST\"}";
+
+    @Test
+    public void run(){
+        Map<String, Object> head = JsonUtil.json2Object(json, new TypeReference<Map<String, Object>>() {
+        });
+        StringBuffer sb = new StringBuffer();
+        head.forEach((k,v)->{
+            sb.append("head.put(\"").append(k).append("\",\"").append(v).append("\");").append("\n");
+        });
+        System.out.println("sb = " + sb);
+    }
+}
diff --git a/java/osx/core/pom.xml b/java/osx/osx-core/pom.xml
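The TargetMethod-to-service mapping listed in buildHead() above also applies when the broker's HTTP entry point is enabled. A minimal sketch of the equivalent HTTP request, assuming a broker configured like MockHttpServer earlier in this patch (open.http.server=true on port 8222) and the default context/servlet paths /osx and /inbound from MetaInfo; the header names follow PtpHttpHeader and buildHead(), while the party id, topic, and payload here are illustrative only:

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class PtpHttpRequestSketch {
    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        // routing metadata travels in x-ptp-* headers, the message body in the POST payload
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8222/osx/inbound"))
                .header("x-ptp-tech-provider-code", "FATE")
                .header("x-ptp-target-method", "PRODUCE_MSG")
                .header("x-ptp-target-node-id", "10008")
                .header("x-ptp-target-component-name", "fateflow")
                .header("x-ptp-session-id", "testSessionId")
                .header("x-ptp-message-topic", "testTopic")
                .POST(HttpRequest.BodyPublishers.ofString("{}"))
                .build();
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}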
similarity index 93% rename from java/osx/core/pom.xml rename to java/osx/osx-core/pom.xml index 75cfc948d1..9506fbfe57 100644 --- a/java/osx/core/pom.xml +++ b/java/osx/osx-core/pom.xml @@ -9,7 +9,7 @@ 4.0.0 - core + osx-core 8 @@ -17,6 +17,11 @@ + + osx + osx-api + ${osx.version} + org.slf4j slf4j-api @@ -82,6 +87,7 @@ commons-io commons-io + diff --git a/java/osx/osx-core/src/main/java/com/osx/core/config/Config.java b/java/osx/osx-core/src/main/java/com/osx/core/config/Config.java new file mode 100644 index 0000000000..8ed09b8c88 --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/config/Config.java @@ -0,0 +1,20 @@ +package com.osx.core.config; + +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; + +@Target({FIELD}) +@Retention(RetentionPolicy.RUNTIME) +@Inherited +public @interface Config { + + String pattern() default ""; +// String defaultValue() default ""; + String confKey(); + + +} diff --git a/java/osx/core/src/main/java/com/osx/core/config/GrpcChannelInfo.java b/java/osx/osx-core/src/main/java/com/osx/core/config/GrpcChannelInfo.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/config/GrpcChannelInfo.java rename to java/osx/osx-core/src/main/java/com/osx/core/config/GrpcChannelInfo.java diff --git a/java/osx/core/src/main/java/com/osx/core/config/MasterInfo.java b/java/osx/osx-core/src/main/java/com/osx/core/config/MasterInfo.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/config/MasterInfo.java rename to java/osx/osx-core/src/main/java/com/osx/core/config/MasterInfo.java diff --git a/java/osx/osx-core/src/main/java/com/osx/core/config/MetaInfo.java b/java/osx/osx-core/src/main/java/com/osx/core/config/MetaInfo.java new file mode 100644 index 0000000000..d5ef25d365 --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/config/MetaInfo.java @@ -0,0 +1,338 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.osx.core.config; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.osx.core.constant.DeployMode; +import com.osx.core.constant.Dict; +import com.osx.core.constant.StreamLimitMode; +import com.osx.core.exceptions.ConfigErrorException; +import com.osx.core.utils.JsonUtil; +import com.osx.core.utils.NetUtils; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Field; +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class MetaInfo { + + static Logger logger = LoggerFactory.getLogger(MetaInfo.class); + + @Config(confKey = "user.home") + public static String PROPERTY_USER_HOME = System.getProperty("user.home"); + @Config(confKey = "user.dir") + public static String PROPERTY_USER_DIR = System.getProperty("user.dir"); + + public static String CURRENT_VERSION = "100"; + @Config(confKey = "fate.tech.provider") + public static String PROPERTY_FATE_TECH_PROVIDER = "FATE"; + @Config(confKey = "default.client.version") + public static String PROPERTY_DEFAULT_CLIENT_VERSION = "2.X.X"; + public static volatile MasterInfo masterInfo; + @Config(confKey = "grpc.server.max.concurrent.call.per.connection", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION = 1000; + @Config(confKey = "grpc.server.max.inbound.metadata.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE = 128 << 20; + @Config(confKey = "grpc.server.max.inbound.message.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE = (2 << 30) - 1; + @Config(confKey = "grpc.server.flow.control.window", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW = 128 << 20; + @Config(confKey = "grpc.server.keepalive.time.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC = 7200; + @Config(confKey = "grpc.server.keepalive.timeout.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC = 3600; + @Config(confKey = "grpc.server.permit.keepalive.time.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC = 10; + @Config(confKey = "grpc.server.keepalive.without.calls.enabled", pattern = Dict.BOOLEAN_PATTERN) + public static boolean PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED = true; + @Config(confKey = "grpc.server.max.connection.idle.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC = 86400; + @Config(confKey = "grpc.server.max.connection.age.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC = 86400; + @Config(confKey = "grpc.server.max.connection.age.grace.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC = 86400; + @Config(confKey = "grpc.oncompleted.wait.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT = 600; + @Config(confKey = "grpc.client.max.inbound.message.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int 
PROPERTY_GRPC_CLIENT_MAX_INBOUND_MESSAGE_SIZE = (2 << 30) - 1; + @Config(confKey = "grpc.client.flow.control.window", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_CLIENT_FLOW_CONTROL_WINDOW = 128 << 20; + @Config(confKey = "grpc.client.keepalive.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_CLIENT_KEEPALIVE_TIME_SEC = 7200; + @Config(confKey = "grpc.client.keepalive.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_CLIENT_KEEPALIVE_TIMEOUT_SEC = 3600; + @Config(confKey = "grpc.client.keepalive.without.calls.enabled", pattern = Dict.BOOLEAN_PATTERN) + public static boolean PROPERTY_GRPC_CLIENT_KEEPALIVE_WITHOUT_CALLS_ENABLED = true; + @Config(confKey = "grpc.client.max.connection.idle", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_CLIENT_MAX_CONNECTION_IDLE_SEC = 86400; + @Config(confKey = "grpc.client.per.rpc.buffer.limit", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_CLIENT_PER_RPC_BUFFER_LIMIT = (2 << 30) - 1; + @Config(confKey = "grpc.client.retry.buffer.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_GRPC_CLIENT_RETRY_BUFFER_SIZE = 86400; + @Config(confKey = "transfer.cached.msgid.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_TRANSFER_CACHED_MSGID_SIZE = 10; + @Config(confKey = "grpc.ssl.session.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SSL_SESSION_TIME_OUT = 3600 << 4; + @Config(confKey = "grpc.ssl.session.cache.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_SSL_SESSION_CACHE_SIZE = 65536; + + @Config(confKey = "mapped.file.expire.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_MAPPED_FILE_EXPIRE_TIME = 3600 * 1000 * 36; + @Config(confKey = "mapped.file.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int MAP_FILE_SIZE = 1 << 25; + @Config(confKey = "mapped.file.dir") + public static String PROPERTY_TRANSFER_FILE_PATH_PRE = "mapped/.fate/transfer_file"; + + @Config(confKey = "index.mapped.file.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_INDEX_MAP_FILE_SIZE = 1 << 21; + @Config(confKey = "server.cert.chain.file") + public static String PROPERTY_SERVER_CERT_CHAIN_FILE; + @Config(confKey = "server.private.key.file") + public static String PROPERTY_SERVER_PRIVATE_KEY_FILE; + @Config(confKey = "server.ca.file") + public static String PROPERTY_SERVER_CA_FILE; + @Config(confKey = "custom.local.host") + public static String PROPERTY_CUSTOMER_LOCAL_HOST; + @Config(confKey = "bind.host") + public static String PROPERTY_BIND_HOST = "0.0.0.0"; + @Config(confKey = "open.grpc.tls.server", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_GRPC_TLS_SERVER = false; + @Config(confKey = "grpc.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_PORT = 9370; + @Config(confKey = "grpc.tls.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_TLS_PORT; + @Config(confKey = "use.remote.health.check", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_USE_REMOTE_HEALTH_CHECK = true; + @Config(confKey = "http.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_PORT; + @Config(confKey = "https.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTPS_PORT; + 
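+    // The fields below configure the broker's optional HTTP(S) entry point
+    // (enabled via open.http.server / http.use.tls in broker.properties); like
+    // every other @Config field in this class they are populated reflectively
+    // from the properties file by MetaInfo.init().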
@Config(confKey = "open.http.server", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_HTTP_SERVER = false; + @Config(confKey = "http.use.tls", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_HTTP_USE_TLS = false; + @Config(confKey = "http.server.acceptor.num", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_HTTP_SERVER_ACCEPTOR_NUM = 10; + @Config(confKey = "http.server.selector.num", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_HTTP_SERVER_SELECTOR_NUM = 1; + @Config(confKey = "http.ssl.trust.store.type") + public static String PROPERTY_HTTP_SSL_TRUST_STORE_TYPE = "PKCS12"; + @Config(confKey = "http.ssl.trust.store.provider") + public static String PROPERTY_HTTP_SSL_TRUST_STORE_PROVIDER = "SUN"; + @Config(confKey = "http.ssl.key.store.alias") + public static String PROPERTY_HTTP_SSL_KEY_STORE_ALIAS = ""; + @Config(confKey = "http.ssl.key.store.password") + public static String PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD = ""; + @Config(confKey = "http.ssl.trust.store.password") + public static String PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD = ""; + @Config(confKey = "http.ssl.trust.store.path") + public static String PROPERTY_HTTP_SSL_TRUST_STORE_PATH = ""; + @Config(confKey = "http.ssl.hostname.verify") + public static Boolean PROPERTY_HTTP_SSL_HOSTNAME_VERIFY = false; + + @Config(confKey = "http.request.body.max.size") + public static int PROPERTY_HTTP_REQUEST_BODY_MAX_SIZE = 32 * 1024 * 1024; + @Config(confKey = "http.context.path") + public static String PROPERTY_HTTP_CONTEXT_PATH = "/osx"; + @Config(confKey = "http.servlet.path") + public static String PROPERTY_HTTP_SERVLET_PATH = "/inbound"; + @Config(confKey = "http.receive.queue.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_HTTP_RECEIVE_QUEUE_SIZE = 36; + @Config(confKey = "http.accept.receive.buffer.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_HTTP_ACCEPT_RECEIVE_BUFFER_SIZE = 4096; + @Config(confKey = "zk.url") + public static String PROPERTY_ZK_URL; + @Config(confKey = "stream.limit.max.try.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_STREAM_LIMIT_MAX_TRY_TIME = 3; + @Config(confKey = "produce.msg.max.try.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_PRODUCE_MSG_MAX_TRY_TIME = 3; + @Config(confKey = "produce.msg.max.try.interval", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_PRODUCE_MSG_RETRY_INTERVAL = 100; + + @Config(confKey = "produce.msg.cache.max.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PRODUCE_MSG_CACHE_MAX_SIZE = 1000; + @Config(confKey = "produce.msg.cache.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PRODUCE_MSG_CACHE_TIMEOUT; + + + @Config(confKey = "flow.control.sample.count", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_FLOW_CONTROL_SAMPLE_COUNT = 10; + @Config(confKey = "flow.control.sample.interval", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_FLOW_CONTROL_SAMPLE_INTERVAL = 1000; + @Config(confKey = "stream.limit.mode") + public static String PROPERTY_STREAM_LIMIT_MODE = StreamLimitMode.NOLIMIT.name(); + @Config(confKey = "deploy.mode") + public static String PROPERTY_DEPLOY_MODE = DeployMode.standalone.name(); + @Config(confKey = "self.party") + public static Set PROPERTY_SELF_PARTY = Sets.newHashSet();// + @Config(confKey = "flow.rule") + public static String 
PROPERTY_FLOW_RULE_TABLE = "broker/flowRule.json"; + @Config(confKey = "use.zookeeper", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_USE_ZOOKEEPER = true; + @Config(confKey = "open.route.cycle.checker", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_ROUTE_CYCLE_CHECKER = false; + + @Config(confKey = "zookeeper.acl.enable", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_ACL_ENABLE = false; + @Config(confKey = "zookeeper.acl.username") + public static String PROPERTY_ACL_USERNAME; + @Config(confKey = "zookeeper.acl.password") + public static String PROPERTY_ACL_PASSWORD; + @Config(confKey = "queue.max.free.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_QUEUE_MAX_FREE_TIME = 60000000; + @Config(confKey = "queue.check.interval", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static int PROPERTY_TRANSFER_QUEUE_CHECK_INTERVAL = 60 * 1000 * 10; + public static String INSTANCE_ID = NetUtils.getLocalHost() + ":" + MetaInfo.PROPERTY_GRPC_PORT; + + + + + @Config(confKey = "eggroll.cluster.manager.ip") + public static String PROPERTY_EGGROLL_CLUSTER_MANANGER_IP; + @Config(confKey = "eggroll.cluster.manager.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT; + + + /** + * 从连接池中申请连接的超时时间 + */ + @Config(confKey = "http.client.method.config") + public static Map> PROPERTY_HTTP_CLIENT_METHOD_CONFIG_MAP =new HashMap<>(); + + @Config(confKey = "http.client.con.req.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT = 500; + /** + * 建立连接的超时时间 + */ + @Config(confKey = "http.client.connection.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT = 10000; + + @Config(confKey = "http.client.max.idle.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_MAX_IDLE_TIME = 5; + /** + * 等待数据 + */ + @Config(confKey = "http.client.socket.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT = 300000; + @Config(confKey = "http.ssl.session.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_SSL_SESSION_TIME_OUT = 3600 << 4; + @Config(confKey = "http.client.pool.max.total", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_INIT_POOL_MAX_TOTAL = 500; + @Config(confKey = "http.client.pool.max.per.router", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE = 200; + @Config(confKey = "open.token.validator", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_TOKEN_VALIDATOR = false; + @Config(confKey = "open.token.generator", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_TOKEN_GENERATOR = false; + + public static String PROPERTY_TOKEN_GENERATOR_CONFIG_PATH; + public static String PROPERTY_CONFIG_DIR; + + + public static boolean isCluster() { + return PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name()); + } + + + public static boolean checkPattern(String pattern, String value) { + Pattern p = Pattern.compile(pattern); + Matcher m = p.matcher(value); + if (m.find()) { + return true; + } else { + return false; + } + } + + public static void init(Properties environment) { + Field[] fields = MetaInfo.class.getFields(); + Arrays.stream(fields).forEach(field -> { + try { + 
+                Config config = field.getDeclaredAnnotation(Config.class);
+                if (config != null) {
+                    Class clazz = field.getType();
+                    String confKey = config.confKey();
+                    Object value = environment.get(confKey);
+                    if (value != null) {
+                        String pattern = config.pattern();
+                        if (StringUtils.isNotEmpty(pattern) && !checkPattern(pattern, value.toString())) {
+                            logger.error("conf {} has wrong value {}, please check config file", confKey, value);
+                            throw new ConfigErrorException("conf " + confKey + " has wrong value : " + value);
+                        }
+                        if (clazz == Integer.class) {
+                            field.set(null, Integer.parseInt(value.toString()));
+                        } else if (clazz == Long.class) {
+                            field.set(null, Long.parseLong(value.toString()));
+                        } else if (clazz == String.class) {
+                            field.set(null, value.toString());
+                        } else if (clazz == Boolean.class) {
+                            field.set(null, Boolean.valueOf(value.toString()));
+                        } else if (clazz.isAssignableFrom(Set.class)) {
+                            Set<String> set = new HashSet<>();
+                            set.addAll(Lists.newArrayList(value.toString().split(",")));
+                            field.set(null, set);
+                        } else if (clazz.isAssignableFrom(Map.class)) {
+                            Map<String, Map<String, Integer>> conConfig = JsonUtil.object2Objcet(value, new TypeReference<Map<String, Map<String, Integer>>>() {
+                            });
+                            field.set(null, conConfig);
+                        }
+                    }
+                    if (StringUtils.isNotEmpty(confKey)) {
+                        logger.info("{}={}", confKey, field.get(null));
+                    }
+                }
+            } catch (Exception e) {
+                logger.error("parse config error", e);
+                throw new ConfigErrorException("parse config error: " + e.getMessage());
+            }
+        });
+    }
+
+    public static Map<String, Object> toMap() {
+        Map<String, Object> result = Maps.newHashMap();
+        Field[] fields = MetaInfo.class.getFields();
+
+        for (Field field : fields) {
+            try {
+                if (field.get(MetaInfo.class) != null) {
+                    String key = Dict.class.getField(field.getName()) != null ? String.valueOf(Dict.class.getField(field.getName()).get(Dict.class)) : field.getName();
+                    result.put(key, field.get(MetaInfo.class));
+                }
+            } catch (IllegalAccessException | NoSuchFieldException e) {
+
+            }
+        }
+        return result;
+    }
+
+    public static void main(String[] args){
+
+        System.err.println((2 << 30) - 1);
+    }
+
+}
diff --git a/java/osx/osx-core/src/main/java/com/osx/core/config/TransferMeta.java b/java/osx/osx-core/src/main/java/com/osx/core/config/TransferMeta.java
new file mode 100644
index 0000000000..adf0abf857
--- /dev/null
+++ b/java/osx/osx-core/src/main/java/com/osx/core/config/TransferMeta.java
@@ -0,0 +1,17 @@
+package com.osx.core.config;
+
+import lombok.Data;
+
+@Data
+public class TransferMeta {
+
+    String srcPartyId;
+    String desPartyId;
+    String srcRole;
+    String desRole;
+    String sessionId;
+    String topic;
+
+}
diff --git a/java/osx/core/src/main/java/com/osx/core/constant/ActionType.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/ActionType.java
similarity index 97%
rename from java/osx/core/src/main/java/com/osx/core/constant/ActionType.java
rename to java/osx/osx-core/src/main/java/com/osx/core/constant/ActionType.java
index 07310638a4..868d50d272 100644
--- a/java/osx/core/src/main/java/com/osx/core/constant/ActionType.java
+++ b/java/osx/osx-core/src/main/java/com/osx/core/constant/ActionType.java
@@ -28,6 +28,7 @@ public enum ActionType {
     INNER_REDIRECT("inner-redirect"),
     LONG_PULLING_ANSWER("long-pulling-answer"),
     MSG_DOWNLOAD("msg-download"),
+    MSG_REDIRECT("msg-redirect"),
     REDIRECT_ACK("redirect-ack"),
     UNARY_CALL("unary-call"),
     UNARY_CALL_NEW("unary-call-new"),
diff --git a/java/osx/core/src/main/java/com/osx/core/constant/DeployMode.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/DeployMode.java
similarity index 100%
rename
from java/osx/core/src/main/java/com/osx/core/constant/DeployMode.java rename to java/osx/osx-core/src/main/java/com/osx/core/constant/DeployMode.java diff --git a/java/osx/core/src/main/java/com/osx/core/constant/Dict.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/Dict.java similarity index 55% rename from java/osx/core/src/main/java/com/osx/core/constant/Dict.java rename to java/osx/osx-core/src/main/java/com/osx/core/constant/Dict.java index ae163bc606..86cc5cc297 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/Dict.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/constant/Dict.java @@ -19,7 +19,7 @@ import com.osx.core.config.MetaInfo; public class Dict { - public static final String PROPERTY_FATE_TECH_PROVIDER="fate.tech.provider"; + public static final String ORIGIN_REQUEST = "origin_request"; public static final String CASEID = "caseid"; public static final String SEQNO = "seqno"; @@ -29,11 +29,9 @@ public class Dict { public static final String HTTP_PORT ="http.port"; public static final String INSTANCE_ID = "instanceId"; - public static final String HIT_CACHE = "hitCache"; - - + public static final String POSITIVE_INTEGER_PATTERN = "^[1-9]\\d*$"; + public static final String BOOLEAN_PATTERN="^(true)|(false)$"; public static final String REQUEST_SEQNO = "REQUEST_SEQNO"; - public static final String VERSION = "version"; public static final String GRPC_TYPE = "grpcType"; public static final String ROUTER_INFO = "routerInfo"; @@ -48,67 +46,18 @@ public class Dict { public static final String DOWN_STREAM_BEGIN = "downstreamBegin"; public static final String ROUTE_BASIS = "routeBasis"; public static final String SOURCE_IP = "sourceIp"; - public static final String PROPERTY_SERVING_CORE_POOL_SIZE = "serving.core.pool.size"; - public static final String SERVING_MAX_POOL_ZIE = "serving.max.pool.size"; - public static final String PROPERTY_SERVING_POOL_ALIVE_TIME = "serving.pool.alive.time"; - public static final String PROPERTY_SERVING_POOL_QUEUE_SIZE = "serving.pool.queue.size"; + //HttpServletResponse + public static final String HTTP_SERVLET_RESPONSE = "httpServletResponse"; - public static final String CACHE_TYPE_REDIS = "redis"; - public static final String DEFAULT_FATE_ROOT = "FATE-SERVICES"; + +// public static final String PROPERTY_BIND_HOST_KEY = "bind.host"; + /** * configuration property key */ - public static final String PROPERTY_SELF_PARTY = "self.party"; - - public static final String PROPERTY_CACHE_TYPE = "cache.type"; - - public static final String PROPERTY_REDIS_EXPIRE = "redis.expire"; - public static final String PROPERTY_REDIS_CLUSTER_NODES = "redis.cluster.nodes"; - public static final String PROPERTY_LOCAL_CACHE_MAXSIZE = "local.cache.maxsize"; - public static final String PROPERTY_LOCAL_CACHE_EXPIRE = "local.cache.expire"; - public static final String PROPERTY_LOCAL_CACHE_INTERVAL = "local.cache.interval"; - - public static final String PROPERTY_GRPC_TIMEOUT = "grpc.timeout"; - public static final String PROPERTY_EXTERNAL_INFERENCE_RESULT_CACHE_DB_INDEX = "external.inferenceResultCacheDBIndex"; - public static final String PROPERTY_EXTERNAL_INFERENCE_RESULT_CACHE_TTL = "external.inferenceResultCacheTTL"; - public static final String PROPERTY_EXTERNAL_REMOTE_MODEL_INFERENCE_RESULT_CACHE_DB_INDEX = "external.remoteModelInferenceResultCacheDBIndex"; - public static final String PROPERTY_EXTERNAL_PROCESS_CACHE_DB_INDEX = "external.processCacheDBIndex"; - public static final String 
PROPERTY_EXTERNAL_REMOTE_MODEL_INFERENCE_RESULT_CACHE_TTL = "external.remoteModelInferenceResultCacheTTL"; - public static final String PROPERTY_CAN_CACHE_RET_CODE = "canCacheRetcode"; - public static final String PROPERTY_SERVICE_ROLE_NAME = "serviceRoleName"; - public static final String PROPERTY_SERVICE_ROLE_NAME_DEFAULT_VALUE = "serving"; - public static final String PROPERTY_ONLINE_DATA_ACCESS_ADAPTER = "OnlineDataAccessAdapter"; - public static final String PROPERTY_ONLINE_DATA_BATCH_ACCESS_ADAPTER = "OnlineDataBatchAccessAdapter"; - public static final String PROPERTY_MODEL_CACHE_ACCESS_TTL = "modelCacheAccessTTL"; - public static final String PROPERTY_MODEL_CACHE_MAX_SIZE = "modelCacheMaxSize"; - public static final String PROPERTY_INFERENCE_WORKER_THREAD_NUM = "inferenceWorkerThreadNum"; - public static final String PROPERTY_PROXY_ADDRESS = "proxy"; - public static final String ONLINE_ENVIRONMENT = "online"; - public static final String PROPERTY_ROLL_ADDRESS = "roll"; - public static final String PROPERTY_FLOW_ADDRESS = "flow"; - public static final String PROPERTY_SERVING_ADDRESS = "serving"; - public static final String PROPERTY_USE_ZOOKEEPER = "useZookeeper"; - public static final String PROPERTY_PORT = "port"; - public static final String PROPERTY_GRPC_PORT = "grpc.port"; - public static final String PROPERTY_GRPC_TLS_PORT = "grpc.tls.port"; - public static final String PROPERTY_USER_DIR = "user.dir"; - public static final String PROPERTY_USER_HOME = "user.home"; - public static final String PROPERTY_FILE_SEPARATOR = "file.separator"; - public static final String PROPERTY_ZK_URL = "zk.url"; - public static final String PROPERTY_USE_ZK_ROUTER = "useZkRouter"; - public static final String PROPERTY_USE_REGISTER = "useRegister"; - public static final String PROPERTY_MODEL_TRANSFER_URL = "model.transfer.url"; - public static final String PROPERTY_MODEL_SYNC = "model.synchronize"; - public static final String PROPERTY_TRANSFER_FILE_PATH = "transfer.file.path"; - - public static final String PROPERTY_FEATURE_BATCH_ADAPTOR = "feature.batch.adaptor"; - public static final String PROPERTY_ACL_ENABLE = "acl.enable"; - public static final String PROPERTY_ACL_USERNAME = "acl.username"; - public static final String PROPERTY_ACL_PASSWORD = "acl.password"; - public static final String PROXY_ROUTER_TABLE = "proxy.router.table"; - public static final String PROPERTY_BATCH_INFERENCE_MAX = "batch.inference.max"; + public static final String PROPERTY_SELF_PARTY_KEY = "self.party"; public static final String PROPERTY_PRINT_INPUT_DATA = "print.input.data"; public static final String PROPERTY_PRINT_OUTPUT_DATA = "print.output.data"; public static final String PROPERTY_NEGOTIATIONTYPE = "server.negotiationType"; @@ -120,68 +69,10 @@ public class Dict { public static final String CURRENT_VERSION = "currentVersion"; public static final String PROPERTY_COORDINATOR = "coordinator"; -// public static final String PROPERTY_SERVER_PORT = "server.port"; - - - public static final String PROPERTY_INFERENCE_SERVICE_NAME = "inference.service.name"; - public static final String PROPERTY_ROUTE_TYPE = "routeType"; - public static final String PROPERTY_ROUTE_TABLE = "route.table"; - public static final String PROPERTY_FLOW_RULE_TABLE = "flow.rule"; - public static final String PROPERTY_AUTH_FILE = "auth.file"; - public static final String PROPERTY_AUTH_OPEN = "auth.open"; - public static final String PROPERTY_PROXY_GRPC_INTRA_PORT = "proxy.grpc.intra.port"; - public static final String PROPERTY_PROXY_GRPC_INTER_PORT 
= "proxy.grpc.inter.port"; - public static final String PROPERTY_PROXY_GRPC_INFERENCE_TIMEOUT = "proxy.grpc.inference.timeout"; - public static final String PROPERTY_PROXY_GRPC_INFERENCE_ASYNC_TIMEOUT = "proxy.grpc.inference.async.timeout"; - public static final String PROPERTY_PROXY_GRPC_UNARYCALL_TIMEOUT = "proxy.grpc.unaryCall.timeout"; - public static final String PROPERTY_PROXY_GRPC_THREADPOOL_CORESIZE = "proxy.grpc.threadpool.coresize"; - public static final String PROPERTY_PROXY_GRPC_THREADPOOL_MAXSIZE = "proxy.grpc.threadpool.maxsize"; - public static final String PROPERTY_PROXY_GRPC_THREADPOOL_QUEUESIZE = "proxy.grpc.threadpool.queuesize"; - public static final String PROPERTY_PROXY_ASYNC_TIMEOUT = "proxy.async.timeout"; - public static final String PROPERTY_PROXY_ASYNC_CORESIZE = "proxy.async.coresize"; - public static final String PROPERTY_PROXY_ASYNC_MAXSIZE = "proxy.async.maxsize"; - public static final String PROPERTY_PROXY_GRPC_BATCH_INFERENCE_TIMEOUT = "proxy.grpc.batch.inference.timeout"; - public static final String PROPERTY_MODEL_CACHE_PATH = "model.cache.path"; - public static final String PROPERTY_LR_USE_PARALLEL = "lr.use.parallel"; - public static final String PROPERTY_ALLOW_HEALTH_CHECK = "health.check.allow"; - public static final String PROPERTY_TRANSFER_FILE_CACHE_SIZE = "transfer.file.cache.size"; - public static final String PROPERTY_MAX_TRANSFER_CACHE_SIZE = "max.transfer.cache.size"; - public static final String PROPERTY_USE_DIRECT_CACHE = "use.direct.cache"; - public static final String PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT = "grpc.oncompleted.wait.timeout"; -// public static final String PROPERTY_USE_QUEUE_MODEL = "use.queue.model"; - public static final String PROPERTY_STREAM_LIMIT_MODE = "stream.limit.mode"; - public static final String PROPERTY_STREAM_LIMIT_MAX_TRY_TIME = "stream.limit.max.try.time"; - public static final String PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION = "grpc.server.max.concurrent.call.per.connection"; - public static final String PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE = "grpc.server.max.inbound.message.size"; - public static final String PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE = "grpc.server.max.inbound.metadata.size"; - public static final String PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW = "grpc.server.flow.control.window"; - public static final String PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC = "grpc.server.keepalive.time.sec"; - public static final String PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC = "grpc.server.keepalive.timeout.sec"; - public static final String PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC = "grpc.server.permit.keepalive.time.sec"; - public static final String PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED = "grpc.server.keepalive.without.calls.enabled"; - public static final String PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC = "grpc.server.max.connection.idle.sec"; - public static final String PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC = "grpc.server.max.connection.age.sec"; - public static final String PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC = "grpc.server.max.connection.age.grace.sec"; - public static final String PROPERTY_INTERVAL_MS = "interval.ms"; - public static final String PROPERTY_SAMPLE_COUNT = "sample.count"; - public static final String PRPPERTY_QUEUE_MAX_FREE_TIME = "queue.max.free.time"; - - public static String PROPERTY_OPEN_HTTP_SERVER = "open.http.server"; - public static String PROPERTY_OPEN_GRPC_TLS_SERVER = "open.grpc.tls.server"; - public static 
String PROPERTY_DEFAULT_CLIENT_VERSION="default.client.version"; - - - public static final String HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT = "httpclinet.config.connection.req.timeout"; - public static final String HTTP_CLIENT_CONFIG_CONN_TIME_OUT = "httpclient.config.connection.timeout"; - public static final String HTTP_CLIENT_CONFIG_SOCK_TIME_OUT = "httpclient.config.sockect.timeout"; - public static final String HTTP_CLIENT_INIT_POOL_MAX_TOTAL = "httpclient.init.pool.maxtotal"; - public static final String HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE = "httpclient.init.pool.def.max.pre.route"; - public static final String HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT = "httpclient.init.pool.sockect.timeout"; - public static final String HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT = "httpclient.init.pool.connection.timeout"; - public static final String HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT = "httpclient.init.pool.connection.req.timeout"; - public static final String HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT = "httpclient.tran.connection.req.timeout"; - public static final String HTTP_CLIENT_TRAN_CONN_TIME_OUT = "httpclient.tran.connection.timeout"; - public static final String HTTP_CLIENT_TRAN_SOCK_TIME_OUT = "httpclient.tran.sockect.timeout"; + + + + public static final String ACTION_TYPE_ASYNC_EXECUTE = "ASYNC_EXECUTE"; @@ -191,6 +82,8 @@ public class Dict { public static final String DATA = "data"; public static final String STATUS = "status"; public static final String SUCCESS = "success"; + public static final String DUP_MSG = "dup_msg"; + public static final String PROCESSED_MSG = "Processed messages"; public static final String PROB = "prob"; public static final String ACCESS = "access"; @@ -233,6 +126,8 @@ public class Dict { public static final String CASE_ID = "caseid"; public static final String CODE = "code"; public static final String MESSAGE = "message"; + public static final String MESSAGE_FLAG = "message_flag"; + public static final String MESSAGE_CODE = "message_code"; public static final String MODEL_ID = "modelId"; public static final String MODEL_VERSION = "modelVersion"; public static final String TIMESTAMP = "timestamp"; @@ -248,7 +143,10 @@ public class Dict { public static final String SELF_ENVIRONMENT = "online"; public static final String HEAD = "head"; public static final String BODY = "body"; - + public static final String SESSION_ID = "sessionId"; + public static final String METHOD_CONFIG_REQ_TIMEOUT = "reqTimeout"; + public static final String METHOD_CONFIG_CONNECTION_TIMEOUT = "connectionTimeout"; + public static final String METHOD_CONFIG_SOCKET_TIMEOUT = "socketTimeout"; public static final String SBT_TREE_NODE_ID_ARRAY = "sbtTreeNodeIdArray"; @@ -274,8 +172,6 @@ public class Dict { public static final String ERROR_LIST = "errorList"; public static final String HEALTH_INFO = "healthInfo"; public static final String PROPERTY_ADMIN_HEALTH_CHECK_TIME = "health.check.time"; - - public static final String ROLLSITE_ROUTE_TABLE_KEY = "rollsite.route.table.key"; public static final String ROLLSITE_ROUTE_TABLE_WHITE_LIST = "rollsite.route.table.whitList"; public static final String ROLLSITE_ROUTE_TABLE_PARTY_ID = "rollsite.route.table.party.id"; @@ -299,15 +195,16 @@ public class Dict { public static final String TOPIC = "topic"; - public static final String PROPERTY_DEPLOY_MODE = "deploy.model"; - public static final String PROPERTY_CLUSTER_MANAGER_ADDRESS = "cluster.manager.address"; + public static final String PROPERTY_DEPLOY_MODE_KEY = "deploy.model"; +// public static final String 
PROPERTY_CLUSTER_MANAGER_ADDRESS = "cluster.manager.address"; - public static final String PROPERTY_EGGROLL_CLUSTER_MANANGER_IP = "eggroll.cluster.manager.ip"; - public static final String PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT = "eggroll.cluster.manager.port"; + public static final String PROPERTY_EGGROLL_CLUSTER_MANANGER_IP_KEY = "eggroll.cluster.manager.ip"; + public static final String PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT_KEY = "eggroll.cluster.manager.port"; public final static String UNKNOWN = "UNKNOWN"; public final static String PROTOBUF = "PROTOBUF"; public final static String SLASH = "/"; + public final static String COMPONENTS_DIR = "components"; public final static String GRPC_PARSE_FROM = "parseFrom"; public final static String AT = "@"; public final static String AND = "&"; @@ -361,6 +258,7 @@ public class Dict { public final static String QUEUE = "queue"; public final static String TOTAL = "total"; public final static String LOCALHOST = "localhost"; + public final static String LOCALHOST2 = "127.0.0.1"; public final static String STORE_TYPE = "storeType"; public final static String STORE_TYPE_SNAKECASE = "store_type"; public final static String NAMESPACE = "namespace"; @@ -371,8 +269,21 @@ public class Dict { public final static String PARTITIONER = "partitioner"; public final static String SERDES = "serdes"; public final static String TRANSFER_BROKER_NAME = "transfer_broker_name"; - public static String PROPERTY_DLEDGER_PEER = "dledger.peer"; - public static String PROPERTY_DLEDGER_SELF = "dledger.self"; + public final static String TRANSFER_QUEUE = "transfer_queue"; + public final static String IS_CYCLE="cycle"; +// public final static String EGGROLL_SEND_TOPIC_PREFIX="EGGROLL_SEND_"; +// public final static String EGGROLL_BACK_TOPIC_PREFIX="EGGROLL_BACK_"; + public final static String STREAM_SEND_TOPIC_PREFIX = "STREAM_SEND_"; + public final static String STREAM_BACK_TOPIC_PREFIX = "STREAM_BACK_"; + public final static String BLOCKING_STUB = "BLOCKING_STUB"; + public final static String PROTOCOL = "protocol"; + public final static String URL="url"; + + public final static String USE_SSL="useSSL"; + public final static String CA_FILE="caFile"; + public final static String CERT_CHAIN_FILE="certChainFile"; + public final static String PRIVATE_KEY_FILE="privateKeyFile"; + } diff --git a/java/osx/core/src/main/java/com/osx/core/constant/EncryptMethod.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/EncryptMethod.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/constant/EncryptMethod.java rename to java/osx/osx-core/src/main/java/com/osx/core/constant/EncryptMethod.java diff --git a/java/osx/core/src/main/java/com/osx/core/constant/NegotiationType.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/NegotiationType.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/constant/NegotiationType.java rename to java/osx/osx-core/src/main/java/com/osx/core/constant/NegotiationType.java diff --git a/java/osx/core/src/main/java/com/osx/core/constant/PtpHttpHeader.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/PtpHttpHeader.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/constant/PtpHttpHeader.java rename to java/osx/osx-core/src/main/java/com/osx/core/constant/PtpHttpHeader.java index 756298010b..5d3d848023 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/PtpHttpHeader.java +++ 
b/java/osx/osx-core/src/main/java/com/osx/core/constant/PtpHttpHeader.java @@ -51,10 +51,15 @@ public class PtpHttpHeader { static public final String TargetComponentName = "x-ptp-target-component-name"; static public final String TargetMethod = "x-ptp-target-method"; + static public final String SourceMethod = "x-ptp-source-method"; + static public final String MessageOffSet = "x-ptp-message-offset"; static public final String InstanceId = "x-ptp-instance-id"; static public final String Timestamp = "x-ptp-timestamp"; + + static public final String MessageFlag = "x-ptp-message-flag"; static public final String ReturnCode = "x-ptp-code"; static public final String ReturnMessage = "x-ptp-message"; + static public final String JobId = "x-ptp-job-id"; } diff --git a/java/osx/osx-core/src/main/java/com/osx/core/constant/Role.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/Role.java new file mode 100644 index 0000000000..4d443dd637 --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/constant/Role.java @@ -0,0 +1,5 @@ +package com.osx.core.constant; + +public enum Role { + fateflow +} diff --git a/java/osx/core/src/main/java/com/osx/core/constant/StatusCode.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/StatusCode.java similarity index 88% rename from java/osx/core/src/main/java/com/osx/core/constant/StatusCode.java rename to java/osx/osx-core/src/main/java/com/osx/core/constant/StatusCode.java index a72244d60b..ab0f07c30f 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/StatusCode.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/constant/StatusCode.java @@ -40,6 +40,12 @@ public class StatusCode { public static final String INVALID_REDIRECT_INFO = "142"; public static final String INVALID_INDEXFILE_DETAIL = "143"; public static final String CREATE_TOPIC_ERROR = "144"; + public static final String CYCLE_ROUTE_ERROR = "145"; + public static final String CONSUME_MSG_TIMEOUT= "146"; + + public static final String SESSION_INIT_ERROR= "147"; + public static final String TRANSFER_QUEUE_REDIRECT = "148"; + } diff --git a/java/osx/core/src/main/java/com/osx/core/constant/StreamLimitMode.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/StreamLimitMode.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/constant/StreamLimitMode.java rename to java/osx/osx-core/src/main/java/com/osx/core/constant/StreamLimitMode.java diff --git a/java/osx/core/src/main/java/com/osx/core/constant/TransferStatus.java b/java/osx/osx-core/src/main/java/com/osx/core/constant/TransferStatus.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/constant/TransferStatus.java rename to java/osx/osx-core/src/main/java/com/osx/core/constant/TransferStatus.java diff --git a/java/osx/core/src/main/java/com/osx/core/context/Context.java b/java/osx/osx-core/src/main/java/com/osx/core/context/FateContext.java similarity index 57% rename from java/osx/core/src/main/java/com/osx/core/context/Context.java rename to java/osx/osx-core/src/main/java/com/osx/core/context/FateContext.java index 434761ece3..994c778f7f 100644 --- a/java/osx/core/src/main/java/com/osx/core/context/Context.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/context/FateContext.java @@ -16,35 +16,72 @@ package com.osx.core.context; import com.google.common.collect.Maps; import com.google.common.util.concurrent.ListenableFuture; +import com.osx.api.router.RouterInfo; +import com.osx.core.config.MetaInfo; import com.osx.core.constant.Dict; 
-import com.osx.core.router.RouterInfo; -import com.osx.core.utils.FlowLogPrinter; +import com.osx.api.constants.Protocol; + import com.osx.core.utils.FlowLogUtil; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Map; -public class Context { - static final String LOGGER_NAME = "flow"; - private static final Logger logger = LoggerFactory.getLogger(LOGGER_NAME); +import com.osx.api.context.Context; +public class FateContext implements Context{ protected long timestamp = System.currentTimeMillis(); protected boolean needAssembleException = false; protected String actionType; protected String sessionId; + protected Protocol protocol; + protected String traceId; + protected String token; + protected String sourceInstId; + protected String desInstId; + protected String techProviderCode; protected boolean needPrintFlowLog = true; + protected boolean needCheckRouterInfo = true; protected Long dataSize; + + public Integer getRetryTime() { + return retryTime; + } + + public void setRetryTime(Integer retryTime) { + this.retryTime = retryTime; + } + + protected Integer retryTime =1; protected Map dataMap = Maps.newHashMap(); long costTime; String resourceName; - Throwable t; - FlowLogPrinter flowLogPrinter = FlowLogUtil::printFlowLog; + String messageFlag; + String messageCode; - public Context(){ + public String getJobId() { + return jobId; + } + + @Override + public void setJobId(String jobId) { + this.jobId = jobId; + } + + String jobId; + + Throwable t; + public FateContext(){ } - public Context(long timestamp, Map dataMap){ + public FateContext(long timestamp, Map dataMap){ timestamp = timestamp; this.dataMap = dataMap; } + + public boolean isDestination(){ + if(StringUtils.isNotEmpty(this.getDesPartyId())) + return MetaInfo.PROPERTY_SELF_PARTY.contains(this.getDesPartyId()); + else + return false; + } public Long getDataSize() { return dataSize; } @@ -56,53 +93,84 @@ public String getTopic() { return dataMap.get(Dict.TOPIC).toString(); return null; } + public String getTechProviderCode() { + return techProviderCode; + } + public void setTechProviderCode(String techProviderCode) { + this.techProviderCode = techProviderCode; + } + public Protocol getProtocol() { + return protocol; + } + public void setProtocol(Protocol protocol) { + this.protocol = protocol; + } + public String getMessageFlag() { + return messageFlag; + } + public String getMessageCode() { + return messageCode; + } + public void setMessageCode(String messageCode) { + this.messageCode = messageCode; + } + public void setMessageFlag(String messageFlag) { + this.messageFlag = messageFlag; + } public void setTopic(String topic) { this.dataMap.put(Dict.TOPIC, topic); } - public String getInstanceId() { return (String) dataMap.get(Dict.INSTANCE_ID); } - public void setInstanceId(String instanceId) { this.dataMap.put(Dict.INSTANCE_ID, instanceId); } - public Throwable getException() { return t; } - public void setException(Throwable t) { this.t = t; } - public String getSessionId() { return this.sessionId; } - public void setSessionId(String sessionId) { this.sessionId = sessionId; } - public String getActionType() { return actionType; } - public void setActionType(String actionType) { this.actionType = actionType; } - public Object getData(Object key) { return dataMap.get(key); } - public Object getDataOrDefault(Object key, Object defaultValue) { return dataMap.getOrDefault(key, defaultValue); } - public void putData(Object key, Object data) { dataMap.put(key, data); 
} + public String getTraceId() { + return traceId; + } + public void setTraceId(String traceId) { + this.traceId = traceId; + } + public String getToken() { + return token; + } + public void setToken(String token) { + this.token = token; + } + public boolean isNeedCheckRouterInfo() { + return needCheckRouterInfo; + } + public void setNeedCheckRouterInfo(boolean needCheckRouterInfo) { + this.needCheckRouterInfo = needCheckRouterInfo; + } public String getCaseId() { if (dataMap.get(Dict.CASEID) != null) { @@ -114,109 +182,89 @@ public String getCaseId() { public void setCaseId(String caseId) { dataMap.put(Dict.CASEID, caseId); } - public long getTimeStamp() { return timestamp; } - public Context subContext() { + public FateContext subContext() { Map newDataMap = Maps.newHashMap(dataMap); - return new Context(this.timestamp, newDataMap); + return new FateContext(this.timestamp, newDataMap); } - public boolean needPrintFlowLog() { return needPrintFlowLog; } - public void setNeedPrintFlowLog(boolean needPrintFlowLog) { this.needPrintFlowLog = needPrintFlowLog; } - public Long getRequestMsgIndex() { return (Long) this.dataMap.get(Dict.REQUEST_INDEX); } - public void setRequestMsgIndex(Long index) { this.dataMap.put(Dict.REQUEST_INDEX, index); } - public Long getCurrentMsgIndex() { return (Long) this.dataMap.get(Dict.CURRENT_INDEX); } - public void setCurrentMsgIndex(Long index) { this.dataMap.put(Dict.CURRENT_INDEX, index); } - public long getCostTime() { return costTime; } - public String getSrcPartyId() { return (String) dataMap.get(Dict.SOURCE_PARTY_ID); } - public void setSrcPartyId(String guestAppId) { dataMap.put(Dict.SOURCE_PARTY_ID, guestAppId); } - public String getDesPartyId() { return (String) dataMap.get(Dict.DES_PARTY_ID); } - public void setDesPartyId(String hostAppid) { dataMap.put(Dict.DES_PARTY_ID, hostAppid); } - public void setSrcComponent(String srcComponent){ dataMap.put(Dict.SOURCE_COMPONENT,srcComponent); } - public String getSrcComponent(){ return (String)dataMap.get(Dict.SOURCE_COMPONENT); } - public void setDesComponent(String desComponent){ dataMap.put(Dict.DES_COMPONENT,desComponent); } - public String getDesComponent(){ return (String)dataMap.get(Dict.DES_COMPONENT); } - public RouterInfo getRouterInfo() { return (RouterInfo) dataMap.get(Dict.ROUTER_INFO); } - public void setRouterInfo(RouterInfo routerInfo) { dataMap.put(Dict.ROUTER_INFO, routerInfo); } - public Object getResultData() { return dataMap.get(Dict.RESULT_DATA); } - public void setResultData(Object resultData) { dataMap.put(Dict.RESULT_DATA, resultData); } - public String getReturnCode() { return (String) dataMap.get(Dict.RETURN_CODE); } - public void setReturnCode(String returnCode) { dataMap.put(Dict.RETURN_CODE, returnCode); } - - public String getReturnMsg() { return (String) dataMap.get(Dict.RET_MSG); } - public void setReturnMsg(String returnMsg) { dataMap.put(Dict.RET_MSG, returnMsg); } - + public String getSelfPartyId(){ + return (String) dataMap.get(Dict.PROPERTY_SELF_PARTY_KEY); + } + public void setSelfPartyId(String partyId){ + dataMap.put(Dict.PROPERTY_SELF_PARTY_KEY,partyId); + } public long getDownstreamCost() { if (dataMap.get(Dict.DOWN_STREAM_COST) != null) { @@ -229,43 +277,33 @@ public long getDownstreamCost() { public void setDownstreamCost(long downstreamCost) { dataMap.put(Dict.DOWN_STREAM_COST, downstreamCost); } - public long getDownstreamBegin() { return dataMap.get(Dict.DOWN_STREAM_BEGIN) != null ? 
(long) dataMap.get(Dict.DOWN_STREAM_BEGIN) : 0; } - public void setDownstreamBegin(long downstreamBegin) { dataMap.put(Dict.DOWN_STREAM_BEGIN, downstreamBegin); } - public String getSourceIp() { return (String) dataMap.get(Dict.SOURCE_IP); } - public void setSourceIp(String sourceIp) { dataMap.put(Dict.SOURCE_IP, sourceIp); } - public String getServiceName() { return (String) dataMap.get(Dict.SERVICE_NAME); } - public void setServiceName(String serviceName) { dataMap.put(Dict.SERVICE_NAME, serviceName); } - public String getCallName() { return (String) dataMap.get(Dict.CALL_NAME); } - public void setCallName(String callName) { dataMap.put(Dict.CALL_NAME, callName); } - public void setRemoteFuture(ListenableFuture future) { this.dataMap.put(Dict.FUTURE, future); } - public String getResourceName() { if (StringUtils.isNotEmpty(resourceName)) { return resourceName; @@ -274,22 +312,72 @@ public String getResourceName() { } return resourceName; } - public boolean needAssembleException() { return needAssembleException; } + public String toString(){ + StringBuffer stringBuffer = new StringBuffer(); + if (this.getProtocol() != null) { + stringBuffer.append(this.getProtocol()).append(SPLIT); + } + if (this.getActionType() != null) { + stringBuffer.append(this.getActionType()).append(SPLIT); + } +// if(context.getSessionId()!=null){ +// stringBuffer.append("session:").append(context.getSessionId()).append(SPLIT); +// } + if (this.getTopic() != null) { + stringBuffer.append("topic:").append(this.getTopic()).append(SPLIT); + } - public FlowLogPrinter getFlowLogPrinter() { - return flowLogPrinter; - } - - public Context setFlowLogPrinter(FlowLogPrinter flowLogPrinter) { - this.flowLogPrinter = flowLogPrinter; - return this; - } - public void printFlowLog() { - if (needPrintFlowLog) { - flowLogPrinter.print(this); + if (this.getMessageFlag() != null) { + stringBuffer.append(this.getMessageFlag()).append(SPLIT); + } + if (this.getRequestMsgIndex() != null) { + stringBuffer.append("req-offset:").append(this.getRequestMsgIndex()).append(SPLIT); + } + if (this.getData(Dict.CURRENT_INDEX) != null) { + stringBuffer.append("offset-in-queue:").append(this.getData(Dict.CURRENT_INDEX)).append(SPLIT); } + if(StringUtils.isNotEmpty(this.messageCode)){ + stringBuffer.append("msg-code:").append(this.getMessageCode()).append(SPLIT); + } + if(this.jobId!=null){ + stringBuffer.append("job-id:").append(this.getJobId()).append(SPLIT); + } + if (this.getSrcPartyId() != null) { + stringBuffer.append("src:").append(this.getSrcPartyId()).append(SPLIT); + } + if (this.getDesPartyId() != null) { + stringBuffer.append("des:").append(this.getDesPartyId()).append(SPLIT); + } + if (this.getReturnCode() != null) { + stringBuffer.append("code:").append(this.getReturnCode()).append(SPLIT); + } + stringBuffer.append("cost:").append(System.currentTimeMillis() - this.getTimeStamp()).append(SPLIT); + if (this.getRouterInfo() != null) { + Protocol protocol = this.getRouterInfo().getProtocol(); + if (protocol != null) { + if (protocol.equals(Protocol.grpc)) { + stringBuffer.append(this.getRouterInfo().getHost() + ":" + this.getRouterInfo().getPort()).append(SPLIT); + } else if (protocol.equals(Protocol.http)) { + stringBuffer.append(this.getRouterInfo().getUrl()).append(SPLIT); + } + } + } + if (this.getDataSize() != null) { + stringBuffer.append("size:").append(this.getDataSize()).append(SPLIT); + } + if(this.retryTime>1){ + stringBuffer.append("retry:").append(this.retryTime).append(SPLIT); + } + if (this.getReturnMsg() != null) { + 
stringBuffer.append("msg:").append(this.getReturnMsg()); + } + + + return stringBuffer.toString(); } + static final String SPLIT= "|"; + } diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/AbstractDataSource.java b/java/osx/osx-core/src/main/java/com/osx/core/datasource/AbstractDataSource.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/datasource/AbstractDataSource.java rename to java/osx/osx-core/src/main/java/com/osx/core/datasource/AbstractDataSource.java diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/AutoRefreshDataSource.java b/java/osx/osx-core/src/main/java/com/osx/core/datasource/AutoRefreshDataSource.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/datasource/AutoRefreshDataSource.java rename to java/osx/osx-core/src/main/java/com/osx/core/datasource/AutoRefreshDataSource.java diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/Converter.java b/java/osx/osx-core/src/main/java/com/osx/core/datasource/Converter.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/datasource/Converter.java rename to java/osx/osx-core/src/main/java/com/osx/core/datasource/Converter.java diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/FileRefreshableDataSource.java b/java/osx/osx-core/src/main/java/com/osx/core/datasource/FileRefreshableDataSource.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/datasource/FileRefreshableDataSource.java rename to java/osx/osx-core/src/main/java/com/osx/core/datasource/FileRefreshableDataSource.java diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/NamedThreadFactory.java b/java/osx/osx-core/src/main/java/com/osx/core/datasource/NamedThreadFactory.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/datasource/NamedThreadFactory.java rename to java/osx/osx-core/src/main/java/com/osx/core/datasource/NamedThreadFactory.java diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/ReadableDataSource.java b/java/osx/osx-core/src/main/java/com/osx/core/datasource/ReadableDataSource.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/datasource/ReadableDataSource.java rename to java/osx/osx-core/src/main/java/com/osx/core/datasource/ReadableDataSource.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/AckIndexException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/AckIndexException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/AckIndexException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/AckIndexException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/BaseException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/BaseException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/BaseException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/BaseException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ConfigErrorException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/ConfigErrorException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ConfigErrorException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/ConfigErrorException.java diff --git 
a/java/osx/core/src/main/java/com/osx/core/exceptions/ConsumeNoMessageException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/ConsumeNoMessageException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ConsumeNoMessageException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/ConsumeNoMessageException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ConsumerNotExistException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/ConsumerNotExistException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ConsumerNotExistException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/ConsumerNotExistException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/CreateTopicErrorException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/CreateTopicErrorException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/CreateTopicErrorException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/CreateTopicErrorException.java diff --git a/java/osx/osx-core/src/main/java/com/osx/core/exceptions/CycleRouteInfoException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/CycleRouteInfoException.java new file mode 100644 index 0000000000..2571bd4294 --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/CycleRouteInfoException.java @@ -0,0 +1,9 @@ +package com.osx.core.exceptions; + +import com.osx.core.constant.StatusCode; + +public class CycleRouteInfoException extends BaseException{ + public CycleRouteInfoException(String msg){ + super(StatusCode.CYCLE_ROUTE_ERROR,msg); + } +} diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java similarity index 64% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java index dcaed77a25..3394ec0ce9 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java @@ -16,11 +16,13 @@ package com.osx.core.exceptions; +import com.osx.api.context.Context; import com.osx.core.constant.Dict; import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; + import io.grpc.Status; import io.grpc.StatusRuntimeException; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +37,8 @@ public class ErrorMessageUtil { static Logger logger = LoggerFactory.getLogger(ErrorMessageUtil.class); + static final String MESSAGE_PREFIX = "PARTY_"; + public static String buildRemoteRpcErrorMsg(int code, String msg) { return new StringBuilder().append("host return code ").append(code) .append(" host msg :").append(msg).toString(); @@ -67,18 +71,25 @@ public static StatusRuntimeException throwableToException(Context context, Throw public static ExceptionInfo handleExceptionExceptionInfo(Context context, Throwable e) { ExceptionInfo exceptionInfo = new ExceptionInfo(); + String selfPartyId = context.getSelfPartyId(); + String oriMessage = e.getMessage(); + String message = ""; + if(StringUtils.isNotEmpty(selfPartyId)){ + message = MESSAGE_PREFIX+selfPartyId+":"+oriMessage; + }else{ + message = oriMessage; + } 
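// A minimal, hypothetical walk-through of the party-id prefixing added above,
// using invented values ("9999", "TRANSFER_QUEUE_NOT_FIND") and the
// MESSAGE_PREFIX constant defined earlier in this class:
//
//     String selfPartyId = "9999";
//     String oriMessage = "TRANSFER_QUEUE_NOT_FIND";
//     String message = StringUtils.isNotEmpty(selfPartyId)
//             ? MESSAGE_PREFIX + selfPartyId + ":" + oriMessage // -> "PARTY_9999:TRANSFER_QUEUE_NOT_FIND"
//             : oriMessage;
//
// When selfPartyId is empty, the original message passes through unchanged.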
if (e instanceof BaseException) { BaseException baseException = (BaseException) e; exceptionInfo.setCode(baseException.getRetcode()); - exceptionInfo.setMessage(baseException.getMessage()); } else { exceptionInfo.setCode(StatusCode.SYSTEM_ERROR); - exceptionInfo.setMessage(e.getMessage()); } + exceptionInfo.setMessage(message); exceptionInfo.setThrowable(e); - if (context.needAssembleException()) { - exceptionInfo.setThrowable(throwableToException(context, e)); - } +// if (context.needAssembleException()) { +// exceptionInfo.setThrowable(throwableToException(context, e)); +// } return exceptionInfo; } @@ -95,35 +106,6 @@ public static Map handleExceptionToMap(Throwable e) { } public static Map handleException(Map result, Throwable e) { -// if (e instanceof IllegalArgumentException) { -// result.put(Dict.CODE, StatusCode.PARAM_ERROR); -// result.put(Dict.MESSAGE, "PARAM_ERROR"); -// } else if (e instanceof NoRouterInfoException) { -// result.put(Dict.CODE, StatusCode.GUEST_ROUTER_ERROR); -// result.put(Dict.MESSAGE, "ROUTER_ERROR"); -// } else if (e instanceof SysException) { -// result.put(Dict.CODE, StatusCode.SYSTEM_ERROR); -// result.put(Dict.MESSAGE, "SYSTEM_ERROR"); -// } else if (e instanceof OverLoadException) { -// result.put(Dict.CODE, StatusCode.OVER_LOAD_ERROR); -// result.put(Dict.MESSAGE, "OVER_LOAD"); -// } else if (e instanceof InvalidRoleInfoException) { -// result.put(Dict.CODE, StatusCode.INVALID_ROLE_ERROR); -// result.put(Dict.MESSAGE, "ROLE_ERROR"); -// } else if (e instanceof ShowDownRejectException) { -// result.put(Dict.CODE, StatusCode.SHUTDOWN_ERROR); -// result.put(Dict.MESSAGE, "SHUTDOWN_ERROR"); -// -// } else if (e instanceof NoResultException) { -// logger.error("NET_ERROR ", e); -// result.put(Dict.CODE, StatusCode.NET_ERROR); -// result.put(Dict.MESSAGE, "NET_ERROR"); -// } else { -// logger.error("SYSTEM_ERROR ", e); -// result.put(Dict.CODE, StatusCode.SYSTEM_ERROR); -// result.put(Dict.MESSAGE, "SYSTEM_ERROR"); -// } - return result; } } diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ExceptionInfo.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/ExceptionInfo.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ExceptionInfo.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/ExceptionInfo.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/InvalidRedirectInfoException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/InvalidRedirectInfoException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/InvalidRedirectInfoException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/InvalidRedirectInfoException.java diff --git a/java/osx/osx-core/src/main/java/com/osx/core/exceptions/InvalidRouteInfoException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/InvalidRouteInfoException.java new file mode 100644 index 0000000000..e6cd28c07b --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/InvalidRouteInfoException.java @@ -0,0 +1,8 @@ +package com.osx.core.exceptions; + +import com.osx.core.constant.StatusCode; + +public class InvalidRouteInfoException extends BaseException{ + + +} diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/MappedFileException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/MappedFileException.java similarity index 100% rename from 
java/osx/core/src/main/java/com/osx/core/exceptions/MappedFileException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/MappedFileException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/MessageParseException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/MessageParseException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/MessageParseException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/MessageParseException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/NoRouterInfoException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/NoRouterInfoException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/NoRouterInfoException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/NoRouterInfoException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ParameterException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/ParameterException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ParameterException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/ParameterException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ProduceMsgExcption.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/ProduceMsgExcption.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ProduceMsgExcption.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/ProduceMsgExcption.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/PutMessageException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/PutMessageException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/PutMessageException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/PutMessageException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/RemoteRpcException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/RemoteRpcException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/RemoteRpcException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/RemoteRpcException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/RouterInfoOperateException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/RouterInfoOperateException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/RouterInfoOperateException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/RouterInfoOperateException.java diff --git a/java/osx/osx-core/src/main/java/com/osx/core/exceptions/SessionInitException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/SessionInitException.java new file mode 100644 index 0000000000..2d0f21d43f --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/SessionInitException.java @@ -0,0 +1,12 @@ +package com.osx.core.exceptions; + +import com.osx.core.constant.StatusCode; + +public class SessionInitException extends BaseException{ + public SessionInitException(String retCode, String message) { + super(retCode, message); + } + public SessionInitException(String message) { + super(StatusCode.SESSION_INIT_ERROR, message); + } +} diff --git 
a/java/osx/core/src/main/java/com/osx/core/exceptions/SysException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/SysException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/SysException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/SysException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueAlreadyExistException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/TransferQueueAlreadyExistException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueAlreadyExistException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/TransferQueueAlreadyExistException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueInvalidStatusException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/TransferQueueInvalidStatusException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueInvalidStatusException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/TransferQueueInvalidStatusException.java diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java similarity index 88% rename from java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java index 2956292fbb..04797705a7 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java @@ -22,7 +22,7 @@ public TransferQueueNotExistException() { super(StatusCode.TRANSFER_QUEUE_NOT_FIND, "TRANSFER_QUEUE_NOT_FIND"); } - public TransferQueueNotExistException(String code, String msg) { - super(code, msg); + public TransferQueueNotExistException( String msg) { + super(StatusCode.TRANSFER_QUEUE_NOT_FIND, msg); } } diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/UnSupportMethodException.java b/java/osx/osx-core/src/main/java/com/osx/core/exceptions/UnSupportMethodException.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/exceptions/UnSupportMethodException.java rename to java/osx/osx-core/src/main/java/com/osx/core/exceptions/UnSupportMethodException.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/AbstractRule.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/AbstractRule.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/AbstractRule.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/AbstractRule.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/BucketLeapArray.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/BucketLeapArray.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/BucketLeapArray.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/BucketLeapArray.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowChecker.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowChecker.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowChecker.java rename to 
java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowChecker.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowConfig.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowConfig.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowConfig.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowConfig.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowEvent.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowEvent.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowEvent.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowEvent.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java index 9e202f33fc..5b5ffea3d9 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java @@ -14,6 +14,7 @@ * limitations under the License. */ package com.osx.core.flow; + import com.fasterxml.jackson.core.type.TypeReference; import com.osx.core.config.MetaInfo; import com.osx.core.datasource.FileRefreshableDataSource; @@ -29,8 +30,7 @@ import java.util.*; import java.util.concurrent.ConcurrentHashMap; -import static com.osx.core.config.MetaInfo.PROPERTY_INTERVAL_MS; -import static com.osx.core.config.MetaInfo.PROPERTY_SAMPLE_COUNT; +import static com.osx.core.config.MetaInfo.*; public final class ClusterFlowRuleManager { @@ -70,34 +70,23 @@ private static void initDefaultProperty() { PropertyListener<List<FlowRule>> listener = new FlowRulePropertyListener(defaultNamespace); registerPropertyInternal(defaultNamespace, defaultProperty, listener); String currentPath = null; - if (MetaInfo.PROPERTY_FLOW_RULE_TABLE != null) { - currentPath = MetaInfo.PROPERTY_FLOW_RULE_TABLE; - } else { - URL url = Thread.currentThread().getContextClassLoader().getResource("flowRule.json"); - - if (url != null) { - currentPath = url.getPath(); - } else { - logger.error("file flowRule.json not found"); - } - } + // Check the local development setup first; if it is not local, fall back to the server-side lookup + currentPath = PROPERTY_CONFIG_DIR + "/" + MetaInfo.PROPERTY_FLOW_RULE_TABLE; logger.info("load flow rule {}", currentPath); - if (currentPath != null) { - File confFile = new File(currentPath); - FileRefreshableDataSource fileRefreshableDataSource = null; - try { - fileRefreshableDataSource = new FileRefreshableDataSource(confFile, (source) -> { - - List<FlowRule> content = JsonUtil.json2List((String) source, new TypeReference<List<FlowRule>>() { - }); - logger.info("load flow rule content {}", content); - return content; + File confFile = new File(currentPath); + FileRefreshableDataSource fileRefreshableDataSource; + try { + fileRefreshableDataSource = new FileRefreshableDataSource(confFile, (source) -> { + + List<FlowRule> content = JsonUtil.json2List((String) source, new TypeReference<List<FlowRule>>() { }); - fileRefreshableDataSource.getProperty().addListener(listener); - } catch (FileNotFoundException e) { - e.printStackTrace(); - logger.error("flow rule file not exist"); - } + logger.info("load flow rule content {}", content); + return content; + }); + fileRefreshableDataSource.getProperty().addListener(listener); + } catch 
(FileNotFoundException e) { + e.printStackTrace(); + logger.error("flow rule file not exist"); } } @@ -177,6 +166,7 @@ private static void registerPropertyInternal(/*@NonNull*/ String namespace, /*@V resetNamespaceFlowIdMapFor(namespace); } } + public static void removeProperty(String namespace) { AssertUtil.notEmpty(namespace, "namespace cannot be empty"); synchronized (UPDATE_LOCK) { @@ -338,7 +328,6 @@ private static void applyClusterFlowRule(List list, /*@Valid*/ String Set flowIdSet = new HashSet<>(); for (FlowRule rule : list) { - System.err.println("===================" + rule); if (!rule.isClusterMode()) { continue; } @@ -368,7 +357,7 @@ private static void applyClusterFlowRule(List list, /*@Valid*/ String // Prepare cluster metric from valid flow ID. ClusterMetricStatistics.putMetricIfAbsent(rule.getResource(), - new ClusterMetric(PROPERTY_SAMPLE_COUNT, PROPERTY_INTERVAL_MS)); + new ClusterMetric(PROPERTY_FLOW_CONTROL_SAMPLE_COUNT, PROPERTY_FLOW_CONTROL_SAMPLE_INTERVAL)); } // Cleanup unused cluster metrics. diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetric.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterMetric.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterMetric.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterMetric.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricBucket.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterMetricBucket.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricBucket.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterMetricBucket.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricLeapArray.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterMetricLeapArray.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricLeapArray.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterMetricLeapArray.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricStatistics.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterMetricStatistics.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricStatistics.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterMetricStatistics.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleConstant.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterRuleConstant.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleConstant.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterRuleConstant.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterRuleUtil.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleUtil.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/ClusterRuleUtil.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/CurrentConcurrencyManager.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/CurrentConcurrencyManager.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/CurrentConcurrencyManager.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/CurrentConcurrencyManager.java diff --git 
a/java/osx/core/src/main/java/com/osx/core/flow/DebugSupport.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/DebugSupport.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/DebugSupport.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/DebugSupport.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/DynamicProperty.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/DynamicProperty.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/DynamicProperty.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/DynamicProperty.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/FileMetricReport.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/FileMetricReport.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/flow/FileMetricReport.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/FileMetricReport.java index 37da21a12c..d2edd121ca 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/FileMetricReport.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/flow/FileMetricReport.java @@ -43,7 +43,7 @@ public void report(List data) { // logger.info("report {}",data); metricWriter.write(TimeUtil.currentTimeMillis(), data); } catch (Exception e) { - e.printStackTrace(); + // e.printStackTrace(); } } } diff --git a/java/osx/core/src/main/java/com/osx/core/flow/FlowCounter.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/FlowCounter.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/FlowCounter.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/FlowCounter.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/FlowCounterManager.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/FlowCounterManager.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/FlowCounterManager.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/FlowCounterManager.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/FlowRule.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/FlowRule.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/FlowRule.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/FlowRule.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/Function.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/Function.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/Function.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/Function.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/GlobalRequestLimiter.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/GlobalRequestLimiter.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/GlobalRequestLimiter.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/GlobalRequestLimiter.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/LeapArray.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/LeapArray.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/LeapArray.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/LeapArray.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/LimitQueue.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/LimitQueue.java similarity index 100% rename from 
java/osx/core/src/main/java/com/osx/core/flow/LimitQueue.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/LimitQueue.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/Metric.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/Metric.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/Metric.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/Metric.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricBucket.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/MetricBucket.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricBucket.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/MetricBucket.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricEvent.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/MetricEvent.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricEvent.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/MetricEvent.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricNode.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/MetricNode.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricNode.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/MetricNode.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricReport.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/MetricReport.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricReport.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/MetricReport.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricSearcher.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/MetricSearcher.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricSearcher.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/MetricSearcher.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricWriter.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/MetricWriter.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricWriter.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/MetricWriter.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricsReader.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/MetricsReader.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricsReader.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/MetricsReader.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/NamespaceFlowProperty.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/NamespaceFlowProperty.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/NamespaceFlowProperty.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/NamespaceFlowProperty.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/OccupySupport.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/OccupySupport.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/OccupySupport.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/OccupySupport.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/Property.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/Property.java similarity index 100% rename 
from java/osx/core/src/main/java/com/osx/core/flow/Property.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/Property.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/PropertyListener.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/PropertyListener.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/PropertyListener.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/PropertyListener.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/RequestLimiter.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/RequestLimiter.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/RequestLimiter.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/RequestLimiter.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/Rule.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/Rule.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/Rule.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/Rule.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/RuleConstant.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/RuleConstant.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/RuleConstant.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/RuleConstant.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/TimeUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/TimeUtil.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/TimeUtil.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/TimeUtil.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/TokenService.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/TokenService.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/TokenService.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/TokenService.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/UnaryLeapArray.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/UnaryLeapArray.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/UnaryLeapArray.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/UnaryLeapArray.java diff --git a/java/osx/core/src/main/java/com/osx/core/flow/WindowWrap.java b/java/osx/osx-core/src/main/java/com/osx/core/flow/WindowWrap.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/flow/WindowWrap.java rename to java/osx/osx-core/src/main/java/com/osx/core/flow/WindowWrap.java diff --git a/java/osx/core/src/main/java/com/osx/core/frame/CountDownLatch.java b/java/osx/osx-core/src/main/java/com/osx/core/frame/CountDownLatch.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/frame/CountDownLatch.java rename to java/osx/osx-core/src/main/java/com/osx/core/frame/CountDownLatch.java diff --git a/java/osx/core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java b/java/osx/osx-core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java similarity index 84% rename from java/osx/core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java rename to java/osx/osx-core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java index c1b12a9a39..6c81bc2fe4 100644 --- a/java/osx/core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java +++ 
b/java/osx/osx-core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java @@ -16,10 +16,10 @@ package com.osx.core.frame; +import com.osx.api.router.RouterInfo; import com.osx.core.config.GrpcChannelInfo; import com.osx.core.exceptions.NoRouterInfoException; import com.osx.core.exceptions.SysException; -import com.osx.core.router.RouterInfo; import io.grpc.ManagedChannel; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; import io.grpc.netty.shaded.io.grpc.netty.NegotiationType; @@ -51,10 +51,14 @@ public static synchronized ManagedChannel createManagedChannel(RouterInfo router } if(usePooled) { if (managedChannelPool.get(routerInfo.toKey()) != null) { + ManagedChannel targetChannel = managedChannelPool.get(routerInfo.toKey()); + // logger.info("channel is shutdown : {} isTerminated {}",targetChannel.isShutdown() ,targetChannel.isTerminated() ,targetChannel.getState(true)); return managedChannelPool.get(routerInfo.toKey()); } else { ManagedChannel managedChannel = createManagedChannel(routerInfo, buildDefaultGrpcChannelInfo()); - managedChannelPool.put(routerInfo.toKey(), managedChannel); + if(managedChannel!=null) { + managedChannelPool.put(routerInfo.toKey(), managedChannel); + } return managedChannel; } }else{ @@ -96,25 +100,20 @@ public static synchronized ManagedChannel createManagedChannel(RouterInfo router .enableRetry() .retryBufferSize(channelInfo.getRetryBufferSize()) .maxRetryAttempts(channelInfo.getMaxRetryAttemps()); - - if (routerInfo != null && NegotiationType.TLS.name().equals(routerInfo.getNegotiationType()) - && StringUtils.isNotBlank(routerInfo.getCertChainFile()) - && StringUtils.isNotBlank(routerInfo.getPrivateKeyFile()) - && StringUtils.isNotBlank(routerInfo.getTrustCertCollectionFile())) { + if (routerInfo.isUseSSL() && NegotiationType.TLS.name().equals(routerInfo.getNegotiationType()) && StringUtils.isNotBlank(routerInfo.getCertChainFile()) && StringUtils.isNotBlank(routerInfo.getPrivateKeyFile()) && StringUtils.isNotBlank(routerInfo.getCaFile())) { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient() .keyManager(new File(routerInfo.getCertChainFile()), new File(routerInfo.getPrivateKeyFile())) - .trustManager(new File(routerInfo.getTrustCertCollectionFile())) + .trustManager(new File(routerInfo.getCaFile())) .sessionTimeout(3600 << 4) .sessionCacheSize(65536); - channelBuilder.sslContext(sslContextBuilder.build()).useTransportSecurity(); - + channelBuilder.sslContext(sslContextBuilder.build()).useTransportSecurity().overrideAuthority(routerInfo.getHost()); } else { channelBuilder.usePlaintext(); } return channelBuilder.build(); } catch (Exception e) { - logger.error("create channel error : ", e); + logger.error("create channel to {} error : ",routerInfo, e); //e.printStackTrace(); } return null; diff --git a/java/osx/core/src/main/java/com/osx/core/frame/Lifecycle.java b/java/osx/osx-core/src/main/java/com/osx/core/frame/Lifecycle.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/frame/Lifecycle.java rename to java/osx/osx-core/src/main/java/com/osx/core/frame/Lifecycle.java diff --git a/java/osx/core/src/main/java/com/osx/core/frame/ServiceDataWrapper.java b/java/osx/osx-core/src/main/java/com/osx/core/frame/ServiceDataWrapper.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/frame/ServiceDataWrapper.java rename to java/osx/osx-core/src/main/java/com/osx/core/frame/ServiceDataWrapper.java diff --git a/java/osx/core/src/main/java/com/osx/core/frame/ServiceThread.java 
b/java/osx/osx-core/src/main/java/com/osx/core/frame/ServiceThread.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/frame/ServiceThread.java rename to java/osx/osx-core/src/main/java/com/osx/core/frame/ServiceThread.java diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JVMGCUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/jvm/JVMGCUtils.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/jvm/JVMGCUtils.java rename to java/osx/osx-core/src/main/java/com/osx/core/jvm/JVMGCUtils.java diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JVMMemoryUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/jvm/JVMMemoryUtils.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/jvm/JVMMemoryUtils.java rename to java/osx/osx-core/src/main/java/com/osx/core/jvm/JVMMemoryUtils.java diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java similarity index 78% rename from java/osx/core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java rename to java/osx/osx-core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java index ad8762dc6a..33cfc6b4a3 100644 --- a/java/osx/core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java @@ -95,19 +95,6 @@ static public int getDeadLockedThreadCount() { } } - public static void main(String[] args) { - for (; ; ) { -// System.out.println("======================================================================="); -// System.out.println("getDaemonThreadCount: " + JVMThreadUtils.getDaemonThreadCount()); -// System.out.println("getNonHeapMemoryUsage: " + JVMThreadUtils.getThreadCount()); -// System.out.println("getPeakThreadCountAndReset: " + JVMThreadUtils.getAndResetPeakThreadCount()); -// System.out.println("getDeadLockedThreadCount: " + JVMThreadUtils.getDeadLockedThreadCount()); - try { - Thread.sleep(5000); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - } + } diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfo.java b/java/osx/osx-core/src/main/java/com/osx/core/jvm/JvmInfo.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/jvm/JvmInfo.java rename to java/osx/osx-core/src/main/java/com/osx/core/jvm/JvmInfo.java diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java b/java/osx/osx-core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java similarity index 89% rename from java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java rename to java/osx/osx-core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java index ebe7cef686..4c550f1543 100644 --- a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java @@ -67,16 +67,6 @@ public void run() { } } - public static void main(String[] args) { - JvmInfoCounter.start(); - while (true) { - System.err.println(JvmInfoCounter.getMemInfos()); - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - } + } diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoLeapArray.java b/java/osx/osx-core/src/main/java/com/osx/core/jvm/JvmInfoLeapArray.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoLeapArray.java rename to java/osx/osx-core/src/main/java/com/osx/core/jvm/JvmInfoLeapArray.java diff --git 
a/java/osx/osx-core/src/main/java/com/osx/core/provider/TechProvider.java b/java/osx/osx-core/src/main/java/com/osx/core/provider/TechProvider.java new file mode 100644 index 0000000000..354d4640d7 --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/provider/TechProvider.java @@ -0,0 +1,53 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.osx.core.provider; + +import io.grpc.stub.StreamObserver; +import org.ppc.ptp.Osx; + + ++import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +public interface TechProvider { + // Handles HTTP/1.x requests + void processHttpInvoke(HttpServletRequest httpServletRequest,HttpServletResponse httpServletResponse); + // Handles non-streaming gRPC requests + void processGrpcInvoke(Osx.Inbound request, + io.grpc.stub.StreamObserver responseObserver); + +// rpc peek (PeekInbound) returns (TransportOutbound); +// rpc pop (PopInbound) returns (TransportOutbound); +// rpc push (PushInbound) returns (TransportOutbound); +// rpc release (ReleaseInbound) returns (TransportOutbound); + + // Handles streaming gRPC requests + public StreamObserver processGrpcTransport(Osx.Inbound inbound, StreamObserver responseObserver); + +// + void processGrpcPeek(Osx.PeekInbound inbound, io.grpc.stub.StreamObserver responseObserver); + + void processGrpcPush(Osx.PushInbound inbound, io.grpc.stub.StreamObserver responseObserver); + + void processGrpcPop(Osx.PopInbound inbound, io.grpc.stub.StreamObserver responseObserver); + + void processGrpcRelease(Osx.ReleaseInbound inbound, io.grpc.stub.StreamObserver responseObserver); + + + + + +} diff --git a/java/osx/osx-core/src/main/java/com/osx/core/ptp/SourceMethod.java b/java/osx/osx-core/src/main/java/com/osx/core/ptp/SourceMethod.java new file mode 100644 index 0000000000..87461bef66 --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/ptp/SourceMethod.java @@ -0,0 +1,6 @@ +package com.osx.core.ptp; + +public enum SourceMethod { + UNARY_CALL, OLDUNARY_CALL,PUSH + +} diff --git a/java/osx/core/src/main/java/com/osx/core/ptp/TargetMethod.java b/java/osx/osx-core/src/main/java/com/osx/core/ptp/TargetMethod.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/ptp/TargetMethod.java rename to java/osx/osx-core/src/main/java/com/osx/core/ptp/TargetMethod.java index 56680b18af..5736cd77ee 100644 --- a/java/osx/core/src/main/java/com/osx/core/ptp/TargetMethod.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/ptp/TargetMethod.java @@ -25,7 +25,8 @@ public enum TargetMethod { CANCEL_TOPIC, PUSH, APPLY_TOKEN, - APPLY_TOPIC + APPLY_TOPIC, + TEST_STREAM diff --git a/java/osx/core/src/main/java/com/osx/core/queue/ClusterTransferQueueInfo.java b/java/osx/osx-core/src/main/java/com/osx/core/queue/ClusterTransferQueueInfo.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/queue/ClusterTransferQueueInfo.java rename to java/osx/osx-core/src/main/java/com/osx/core/queue/ClusterTransferQueueInfo.java
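To make the TechProvider contract above concrete: a minimal, self-contained sketch of registering handlers under a tech-provider code (as carried by FateContext's techProviderCode) and dispatching by that code. MiniProvider, ProviderRegistryDemo, and the code string "demo" are invented stand-ins, not names from this patch.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ProviderRegistryDemo {
    // Invented stand-in for the TechProvider contract above, reduced to one method.
    interface MiniProvider {
        String invoke(String payload);
    }

    // Handlers keyed by tech-provider code (cf. FateContext#techProviderCode).
    private static final Map<String, MiniProvider> REGISTRY = new ConcurrentHashMap<>();

    static void register(String techProviderCode, MiniProvider provider) {
        REGISTRY.put(techProviderCode, provider);
    }

    static String dispatch(String techProviderCode, String payload) {
        MiniProvider provider = REGISTRY.get(techProviderCode);
        if (provider == null) {
            throw new IllegalArgumentException("no provider registered for code " + techProviderCode);
        }
        return provider.invoke(payload);
    }

    public static void main(String[] args) {
        register("demo", payload -> "handled:" + payload);
        System.out.println(dispatch("demo", "ping")); // prints "handled:ping"
    }
}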
diff --git a/java/osx/core/src/main/java/com/osx/core/queue/TranferQueueInfo.java b/java/osx/osx-core/src/main/java/com/osx/core/queue/TranferQueueInfo.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/queue/TranferQueueInfo.java rename to java/osx/osx-core/src/main/java/com/osx/core/queue/TranferQueueInfo.java diff --git a/java/osx/core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java b/java/osx/osx-core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java similarity index 70% rename from java/osx/core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java rename to java/osx/osx-core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java index 4a81e442e7..6f5ff68a21 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java @@ -18,12 +18,13 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.osx.api.context.Context; import com.osx.core.constant.Dict; import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; +import com.osx.core.context.FateContext; import com.osx.core.exceptions.ErrorMessageUtil; import com.osx.core.exceptions.ExceptionInfo; -import com.osx.core.utils.JsonUtil; +import com.osx.core.utils.FlowLogUtil; import io.grpc.stub.AbstractStub; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,17 +39,17 @@ * @Author **/ -public abstract class AbstractServiceAdaptor implements ServiceAdaptor { +public abstract class AbstractServiceAdaptor implements ServiceAdaptor { static public AtomicInteger requestInHandle = new AtomicInteger(0); public static boolean isOpen = true; - protected Logger flowLogger = LoggerFactory.getLogger("flow"); +// protected Logger flowLogger = LoggerFactory.getLogger("flow"); protected String serviceName; Logger logger = LoggerFactory.getLogger(this.getClass().getName()); - ServiceAdaptor serviceAdaptor; - InterceptorChain preChain = new DefaultInterceptorChain(); - InterceptorChain postChain = new DefaultInterceptorChain(); + ServiceAdaptor serviceAdaptor; + InterceptorChain preChain = new DefaultInterceptorChain<>(); + InterceptorChain postChain = new DefaultInterceptorChain<>(); private Map methodMap = Maps.newHashMap(); private AbstractStub serviceStub; @@ -69,7 +70,7 @@ public void setMethodMap(Map methodMap) { this.methodMap = methodMap; } - public AbstractServiceAdaptor addPreProcessor(Interceptor interceptor) { + public AbstractServiceAdaptor addPreProcessor(Interceptor interceptor) { preChain.addInterceptor(interceptor); return this; } @@ -78,7 +79,7 @@ public void addPostProcessor(Interceptor interceptor) { postChain.addInterceptor(interceptor); } - public ServiceAdaptor getServiceAdaptor() { + public ServiceAdaptor getServiceAdaptor() { return serviceAdaptor; } @@ -102,7 +103,7 @@ public void setServiceName(String serviceName) { this.serviceName = serviceName; } - protected abstract resp doService(Context context, InboundPackage data); + protected abstract resp doService(ctx context, InboundPackage data); /** * @param context @@ -111,7 +112,7 @@ public void setServiceName(String serviceName) { * @throws Exception */ @Override - public OutboundPackage service(Context context, InboundPackage data) throws RuntimeException { + public OutboundPackage service(ctx context, InboundPackage data) throws RuntimeException { OutboundPackage outboundPackage = new OutboundPackage(); // 
context.preProcess(); @@ -129,19 +130,14 @@ public OutboundPackage service(Context context, InboundPackage data) resp result = null; context.setServiceName(this.serviceName); try { - preChain.doPreProcess(context, data); + preChain.doProcess(context, data, outboundPackage); result = doService(context, data); - if (logger.isDebugEnabled()) { - logger.debug("do service, router info: {}, service name: {}, result: {}", JsonUtil.object2Json(context.getRouterInfo()), serviceName, result); - } } catch (Throwable e) { exceptions.add(e); e.printStackTrace(); logger.error("do service fail, cause by: {}", e.getMessage()); } outboundPackage.setData(result); - //postChain.doPostProcess(context, data, outboundPackage); - } catch (Throwable e) { exceptions.add(e); logger.error("service error", e); @@ -152,7 +148,16 @@ public OutboundPackage service(Context context, InboundPackage data) outboundPackage = this.serviceFail(context, data, exceptions); } } finally { - printFlowLog(context); + if(context instanceof FateContext ) + { + FateContext fateContext =(FateContext )context; + if(fateContext.needPrintFlowLog()){ + FlowLogUtil.printFlowLog(context); + } + }else { + + FlowLogUtil.printFlowLog(context); + } } // int returnCode = context.getReturnCode(); @@ -162,24 +167,20 @@ public OutboundPackage service(Context context, InboundPackage data) // context.postProcess(data, outboundPackage); } + try { + postChain.doProcess(context, data, outboundPackage); + } catch (Exception e) { + logger.error("service PostDoProcess error", e); + } return outboundPackage; } - protected void printFlowLog(Context context) { - - context.printFlowLog(); - -// flowLogger.info("{}|{}|{}|{}|" + -// "{}|{}|{}|{}|" + -// "{}|{}", -// context.getSourceIp(), context.getSrcPartyId(), -// context.getDesPartyId(), context.getReturnCode(), context.getCostTime(), -// context.getDownstreamCost(), serviceName, context.getRouterInfo() != null ? 
context.getRouterInfo() : "", -// MetaInfo.PROPERTY_PRINT_INPUT_DATA?context.getData(Dict.INPUT_DATA):"", -// MetaInfo.PROPERTY_PRINT_OUTPUT_DATA?context.getData(Dict.OUTPUT_DATA):""); - } +// protected void printFlowLog(ctx context) { +//// context.printFlowLog(); +// FlowLogUtil.printFlowLog(context); +// } - protected OutboundPackage serviceFailInner(Context context, InboundPackage data, Throwable e) { + protected OutboundPackage serviceFailInner(ctx context, InboundPackage data, Throwable e) { OutboundPackage outboundPackage = new OutboundPackage(); ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); context.setReturnCode(exceptionInfo.getCode()); @@ -191,13 +192,13 @@ protected OutboundPackage serviceFailInner(Context context, InboundPackage } @Override - public OutboundPackage serviceFail(Context context, InboundPackage data, List errors) throws RuntimeException { + public OutboundPackage serviceFail(ctx context, InboundPackage data, List errors) throws RuntimeException { Throwable e = errors.get(0); logger.error("service fail ", e); return serviceFailInner(context, data, e); } - protected abstract resp transformExceptionInfo(Context context, ExceptionInfo exceptionInfo); + protected abstract resp transformExceptionInfo(ctx context, ExceptionInfo exceptionInfo); } \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java b/java/osx/osx-core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java similarity index 65% rename from java/osx/core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java rename to java/osx/osx-core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java index adfd18a0fd..84d0453a67 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java @@ -17,7 +17,7 @@ package com.osx.core.service; import com.google.common.collect.Lists; -import com.osx.core.context.Context; +import com.osx.api.context.Context; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -27,14 +27,14 @@ * @Description TODO * @Author **/ -public class DefaultInterceptorChain implements InterceptorChain { +public class DefaultInterceptorChain implements InterceptorChain { Logger logger = LoggerFactory.getLogger(DefaultInterceptorChain.class); - List> chain = Lists.newArrayList(); + List> chain = Lists.newArrayList(); @Override - public void addInterceptor(Interceptor interceptor) { + public void addInterceptor(Interceptor interceptor) { chain.add(interceptor); } @@ -46,12 +46,11 @@ public void addInterceptor(Interceptor interceptor) { * @throws Exception */ @Override - public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { - for (Interceptor interceptor : chain) { - logger.info("====== {}",interceptor); - interceptor.doPreProcess(context, inboundPackage); - + public void doProcess(ctx context, InboundPackage inboundPackage,OutboundPackage outboundPackage) throws Exception { + for (Interceptor interceptor : chain) { + if (interceptor != null) { + interceptor.doProcess(context, inboundPackage,outboundPackage); + } } } - } diff --git a/java/osx/core/src/main/java/com/osx/core/service/InboundPackage.java b/java/osx/osx-core/src/main/java/com/osx/core/service/InboundPackage.java similarity index 86% rename from java/osx/core/src/main/java/com/osx/core/service/InboundPackage.java rename to 
java/osx/osx-core/src/main/java/com/osx/core/service/InboundPackage.java index 652b47bd13..3798dc158b 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/InboundPackage.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/service/InboundPackage.java @@ -16,7 +16,7 @@ package com.osx.core.service; -import com.osx.core.router.RouterInfo; + import io.grpc.ManagedChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,7 +28,7 @@ public class InboundPackage { static Logger logger = LoggerFactory.getLogger(InboundPackage.class); ManagedChannel managedChannel; - RouterInfo routerInfo; + String source; Map head; @@ -42,14 +42,6 @@ public void setManagedChannel(ManagedChannel managedChannel) { this.managedChannel = managedChannel; } - public RouterInfo getRouterInfo() { - return routerInfo; - } - - public void setRouterInfo(RouterInfo routerInfo) { - this.routerInfo = routerInfo; - } - public String getSource() { return source; } diff --git a/java/osx/core/src/main/java/com/osx/core/service/Interceptor.java b/java/osx/osx-core/src/main/java/com/osx/core/service/Interceptor.java similarity index 73% rename from java/osx/core/src/main/java/com/osx/core/service/Interceptor.java rename to java/osx/osx-core/src/main/java/com/osx/core/service/Interceptor.java index 611c811ea4..a3e989f278 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/Interceptor.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/service/Interceptor.java @@ -17,11 +17,12 @@ package com.osx.core.service; -import com.osx.core.context.Context; +import com.osx.api.context.Context; -public interface Interceptor { +public interface Interceptor { - default public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { + default public void doProcess(ctx context, InboundPackage inboundPackage, OutboundPackage outboundPackage) throws Exception { } + } diff --git a/java/osx/core/src/main/java/com/osx/core/service/InterceptorChain.java b/java/osx/osx-core/src/main/java/com/osx/core/service/InterceptorChain.java similarity index 75% rename from java/osx/core/src/main/java/com/osx/core/service/InterceptorChain.java rename to java/osx/osx-core/src/main/java/com/osx/core/service/InterceptorChain.java index 627058b569..3c84a1bdcb 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/InterceptorChain.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/service/InterceptorChain.java @@ -17,8 +17,10 @@ package com.osx.core.service; -public interface InterceptorChain extends Interceptor { +import com.osx.api.context.Context; - public void addInterceptor(Interceptor interceptor); +public interface InterceptorChain extends Interceptor { + + public void addInterceptor(Interceptor interceptor); } diff --git a/java/osx/core/src/main/java/com/osx/core/service/OutboundPackage.java b/java/osx/osx-core/src/main/java/com/osx/core/service/OutboundPackage.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/service/OutboundPackage.java rename to java/osx/osx-core/src/main/java/com/osx/core/service/OutboundPackage.java diff --git a/java/osx/core/src/main/java/com/osx/core/service/ServiceAdaptor.java b/java/osx/osx-core/src/main/java/com/osx/core/service/ServiceAdaptor.java similarity index 70% rename from java/osx/core/src/main/java/com/osx/core/service/ServiceAdaptor.java rename to java/osx/osx-core/src/main/java/com/osx/core/service/ServiceAdaptor.java index d526f4c468..a0479e544a 100644 --- 
a/java/osx/core/src/main/java/com/osx/core/service/ServiceAdaptor.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/service/ServiceAdaptor.java @@ -15,16 +15,14 @@ */ package com.osx.core.service; - - -import com.osx.core.context.Context; +import com.osx.api.context.Context; import java.util.List; -public interface ServiceAdaptor { +public interface ServiceAdaptor { - public OutboundPackage service(Context context, InboundPackage inboundPackage); + public OutboundPackage service(ctx context, InboundPackage inboundPackage); - public OutboundPackage serviceFail(Context context, InboundPackage data, List e); + public OutboundPackage serviceFail(ctx context, InboundPackage data, List e); } diff --git a/java/osx/core/src/main/java/com/osx/core/timer/HashedWheelTimer.java b/java/osx/osx-core/src/main/java/com/osx/core/timer/HashedWheelTimer.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/timer/HashedWheelTimer.java rename to java/osx/osx-core/src/main/java/com/osx/core/timer/HashedWheelTimer.java diff --git a/java/osx/core/src/main/java/com/osx/core/timer/Timeout.java b/java/osx/osx-core/src/main/java/com/osx/core/timer/Timeout.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/timer/Timeout.java rename to java/osx/osx-core/src/main/java/com/osx/core/timer/Timeout.java diff --git a/java/osx/core/src/main/java/com/osx/core/timer/Timer.java b/java/osx/osx-core/src/main/java/com/osx/core/timer/Timer.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/timer/Timer.java rename to java/osx/osx-core/src/main/java/com/osx/core/timer/Timer.java diff --git a/java/osx/core/src/main/java/com/osx/core/timer/TimerTask.java b/java/osx/osx-core/src/main/java/com/osx/core/timer/TimerTask.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/timer/TimerTask.java rename to java/osx/osx-core/src/main/java/com/osx/core/timer/TimerTask.java diff --git a/java/osx/core/src/main/java/com/osx/core/token/TokenRequest.java b/java/osx/osx-core/src/main/java/com/osx/core/token/TokenRequest.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/token/TokenRequest.java rename to java/osx/osx-core/src/main/java/com/osx/core/token/TokenRequest.java diff --git a/java/osx/core/src/main/java/com/osx/core/token/TokenResult.java b/java/osx/osx-core/src/main/java/com/osx/core/token/TokenResult.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/token/TokenResult.java rename to java/osx/osx-core/src/main/java/com/osx/core/token/TokenResult.java diff --git a/java/osx/core/src/main/java/com/osx/core/token/TokenResultStatus.java b/java/osx/osx-core/src/main/java/com/osx/core/token/TokenResultStatus.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/token/TokenResultStatus.java rename to java/osx/osx-core/src/main/java/com/osx/core/token/TokenResultStatus.java diff --git a/java/osx/core/src/main/java/com/osx/core/utils/AssertUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/AssertUtil.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/utils/AssertUtil.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/AssertUtil.java diff --git a/java/osx/osx-core/src/main/java/com/osx/core/utils/ClassUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/ClassUtils.java new file mode 100644 index 0000000000..dbdde882b1 --- /dev/null +++ 
b/java/osx/osx-core/src/main/java/com/osx/core/utils/ClassUtils.java @@ -0,0 +1,392 @@ + +package com.osx.core.utils; + + + +import org.apache.commons.lang3.StringUtils; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.lang.reflect.Array; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + + +public class ClassUtils { + + public static final String CLASS_EXTENSION = ".class"; + + public static final String JAVA_EXTENSION = ".java"; + private static final int JIT_LIMIT = 5 * 1024; + + private ClassUtils() { + } + + public static Object newInstance(String name) { + try { + return forName(name).getDeclaredConstructor().newInstance(); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new IllegalStateException(e.getMessage(), e); + } + } + + public static Class forName(String[] packages, String className) { + try { + return classForName(className); + } catch (ClassNotFoundException e) { + if (packages != null && packages.length > 0) { + for (String pkg : packages) { + try { + return classForName(pkg + "." + className); + } catch (ClassNotFoundException ignore) { + } + } + } + throw new IllegalStateException(e.getMessage(), e); + } + } + + public static Class forName(String className) { + try { + return classForName(className); + } catch (ClassNotFoundException e) { + throw new IllegalStateException(e.getMessage(), e); + } + } + + public static Class classForName(String className) throws ClassNotFoundException { + switch (className) { + case "boolean": + return boolean.class; + case "byte": + return byte.class; + case "char": + return char.class; + case "short": + return short.class; + case "int": + return int.class; + case "long": + return long.class; + case "float": + return float.class; + case "double": + return double.class; + case "boolean[]": + return boolean[].class; + case "byte[]": + return byte[].class; + case "char[]": + return char[].class; + case "short[]": + return short[].class; + case "int[]": + return int[].class; + case "long[]": + return long[].class; + case "float[]": + return float[].class; + case "double[]": + return double[].class; + default: + } + try { + return arrayForName(className); + } catch (ClassNotFoundException e) { + // try to load from java.lang package + if (className.indexOf('.') == -1) { + try { + return arrayForName("java.lang." + className); + } catch (ClassNotFoundException ignore) { + // ignore, let the original exception be thrown + } + } + throw e; + } + } + + private static Class arrayForName(String className) throws ClassNotFoundException { + return Class.forName(className.endsWith("[]") + ? 
"[L" + className.substring(0, className.length() - 2) + ";" + : className, true, Thread.currentThread().getContextClassLoader()); + } + + public static Class getBoxedClass(Class type) { + if (type == boolean.class) { + return Boolean.class; + } else if (type == char.class) { + return Character.class; + } else if (type == byte.class) { + return Byte.class; + } else if (type == short.class) { + return Short.class; + } else if (type == int.class) { + return Integer.class; + } else if (type == long.class) { + return Long.class; + } else if (type == float.class) { + return Float.class; + } else if (type == double.class) { + return Double.class; + } else { + return type; + } + } + + public static Boolean boxed(boolean v) { + return Boolean.valueOf(v); + } + + public static Character boxed(char v) { + return Character.valueOf(v); + } + + public static Byte boxed(byte v) { + return Byte.valueOf(v); + } + + public static Short boxed(short v) { + return Short.valueOf(v); + } + + public static Integer boxed(int v) { + return Integer.valueOf(v); + } + + public static Long boxed(long v) { + return Long.valueOf(v); + } + + public static Float boxed(float v) { + return Float.valueOf(v); + } + + public static Double boxed(double v) { + return Double.valueOf(v); + } + + public static Object boxed(Object v) { + return v; + } + + public static boolean unboxed(Boolean v) { + return v == null ? false : v.booleanValue(); + } + + public static char unboxed(Character v) { + return v == null ? '\0' : v.charValue(); + } + + public static byte unboxed(Byte v) { + return v == null ? 0 : v.byteValue(); + } + + public static short unboxed(Short v) { + return v == null ? 0 : v.shortValue(); + } + + public static int unboxed(Integer v) { + return v == null ? 0 : v.intValue(); + } + + public static long unboxed(Long v) { + return v == null ? 0 : v.longValue(); + } + + public static float unboxed(Float v) { + return v == null ? 0 : v.floatValue(); + } + + public static double unboxed(Double v) { + return v == null ? 
0 : v.doubleValue(); + } + + public static Object unboxed(Object v) { + return v; + } + + public static boolean isNotEmpty(Object object) { + return getSize(object) > 0; + } + + public static int getSize(Object object) { + if (object == null) { + return 0; + } + if (object instanceof Collection) { + return ((Collection) object).size(); + } else if (object instanceof Map) { + return ((Map) object).size(); + } else if (object.getClass().isArray()) { + return Array.getLength(object); + } else { + return -1; + } + } + + public static URI toURI(String name) { + try { + return new URI(name); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public static boolean isBeforeJava5(String javaVersion) { + return (StringUtils.isEmpty(javaVersion) || "1.0".equals(javaVersion) + || "1.1".equals(javaVersion) || "1.2".equals(javaVersion) + || "1.3".equals(javaVersion) || "1.4".equals(javaVersion)); + } + + public static boolean isBeforeJava6(String javaVersion) { + return isBeforeJava5(javaVersion) || "1.5".equals(javaVersion); + } + + public static String toString(Throwable e) { + StringWriter w = new StringWriter(); + PrintWriter p = new PrintWriter(w); + p.print(e.getClass().getName() + ": "); + if (e.getMessage() != null) { + p.print(e.getMessage() + "\n"); + } + p.println(); + try { + e.printStackTrace(p); + return w.toString(); + } finally { + p.close(); + } + } + + public static void checkBytecode(String name, byte[] bytecode) { + if (bytecode.length > JIT_LIMIT) { + System.err.println("The template bytecode too long, may be affect the JIT compiler. template class: " + name); + } + } + + public static String getSizeMethod(Class cls) { + try { + return cls.getMethod("size", new Class[0]).getName() + "()"; + } catch (NoSuchMethodException e) { + try { + return cls.getMethod("length", new Class[0]).getName() + "()"; + } catch (NoSuchMethodException e2) { + try { + return cls.getMethod("getSize", new Class[0]).getName() + "()"; + } catch (NoSuchMethodException e3) { + try { + return cls.getMethod("getLength", new Class[0]).getName() + "()"; + } catch (NoSuchMethodException e4) { + return null; + } + } + } + } + } + + public static String getMethodName(Method method, Class[] parameterClasses, String rightCode) { + StringBuilder buf = new StringBuilder(rightCode); + if (method.getParameterTypes().length > parameterClasses.length) { + Class[] types = method.getParameterTypes(); + for (int i = parameterClasses.length; i < types.length; i++) { + if (buf.length() > 0) { + buf.append(','); + } + Class type = types[i]; + String def; + if (type == boolean.class) { + def = "false"; + } else if (type == char.class) { + def = "\'\\0\'"; + } else if (type == byte.class + || type == short.class + || type == int.class + || type == long.class + || type == float.class + || type == double.class) { + def = "0"; + } else { + def = "null"; + } + buf.append(def); + } + } + return method.getName() + "(" + buf + ")"; + } + + public static Method searchMethod(Class currentClass, String name, Class[] parameterTypes) throws NoSuchMethodException { + if (currentClass == null) { + throw new NoSuchMethodException("class == null"); + } + try { + return currentClass.getMethod(name, parameterTypes); + } catch (NoSuchMethodException e) { + for (Method method : currentClass.getMethods()) { + if (method.getName().equals(name) + && parameterTypes.length == method.getParameterTypes().length + && Modifier.isPublic(method.getModifiers())) { + if (parameterTypes.length > 0) { + Class[] types = 
method.getParameterTypes(); + boolean match = true; + for (int i = 0; i < parameterTypes.length; i++) { + if (!types[i].isAssignableFrom(parameterTypes[i])) { + match = false; + break; + } + } + if (!match) { + continue; + } + } + return method; + } + } + throw e; + } + } + + public static String getInitCode(Class type) { + if (byte.class.equals(type) + || short.class.equals(type) + || int.class.equals(type) + || long.class.equals(type) + || float.class.equals(type) + || double.class.equals(type)) { + return "0"; + } else if (char.class.equals(type)) { + return "'\\0'"; + } else if (boolean.class.equals(type)) { + return "false"; + } else { + return "null"; + } + } + + public static Map toMap(Map.Entry[] entries) { + Map map = new HashMap(); + if (entries != null && entries.length > 0) { + for (Map.Entry entry : entries) { + map.put(entry.getKey(), entry.getValue()); + } + } + return map; + } + + /** + * get simple class name from qualified class name + */ + public static String getSimpleClassName(String qualifiedName) { + if (null == qualifiedName) { + return null; + } + int i = qualifiedName.lastIndexOf('.'); + return i < 0 ? qualifiedName : qualifiedName.substring(i + 1); + } + +} diff --git a/java/osx/core/src/main/java/com/osx/core/utils/EncryptUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/EncryptUtils.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/utils/EncryptUtils.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/EncryptUtils.java index 7c676ad2fc..a2a4bd5510 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/EncryptUtils.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/utils/EncryptUtils.java @@ -38,8 +38,8 @@ public static String encrypt(String originString, EncryptMethod encryptMethod) { result += Integer.toHexString((0x000000FF & s[i]) | 0xFFFFFF00).substring(6); } return result; - } catch (Exception e) { - e.printStackTrace(); + } catch (Exception ignore) { + } return "";
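The ClassUtils helper above resolves primitive, array, and unqualified class names with a java.lang fallback; a small illustrative check of that behavior (the demo class is not part of the patch):

import com.osx.core.utils.ClassUtils;

public class ClassUtilsDemo {
    public static void main(String[] args) throws ClassNotFoundException {
        // Primitive and primitive-array names are resolved via the switch table.
        Class<?> ints = ClassUtils.classForName("int[]");            // -> class [I
        // Unqualified names fall back to the java.lang package on failure.
        Class<?> str = ClassUtils.classForName("String");            // -> java.lang.String
        // Object-array names are rewritten to JVM descriptor form ("[L...;").
        Class<?> arr = ClassUtils.classForName("java.lang.String[]");
        System.out.println(ints + " / " + str + " / " + arr);
    }
}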
diff --git a/java/osx/osx-core/src/main/java/com/osx/core/utils/FileUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/FileUtils.java new file mode 100644 index 0000000000..ae8aef0d09 --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/utils/FileUtils.java @@ -0,0 +1,137 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.osx.core.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileLock; + +public class FileUtils { + private static final Logger logger = LoggerFactory.getLogger(FileUtils.class); + + public static boolean writeFile(String context, File target) { + BufferedWriter out = null; + try { + if (!target.exists()) { + target.createNewFile(); + } + out = new BufferedWriter(new FileWriter(target)); + out.write(context); + } catch (IOException e) { + logger.error(e.getMessage()); + return false; + } finally { + try { + if (out != null) { + out.flush(); + out.close(); + } + } catch (IOException ex) { + logger.error("write file error", ex); + } + } + return true; + } + + /** + * Write a string to a file, replacing its content; + * synchronized via an exclusive file lock + */ + public static boolean writeStr2ReplaceFileSync(String str, String pathFile) throws Exception { + File file = new File(pathFile); + try { + if (!file.exists()) { + file.createNewFile(); + } + } catch (IOException e) { + logger.error("Failed to create the file. Check whether the path is valid and the read/write permission is correct"); + throw new IOException("Failed to create the file. Check whether the path is valid and the read/write permission is correct"); + } + FileOutputStream fileOutputStream = null; + FileChannel fileChannel = null; + FileLock fileLock; + try { + + /* + * write file + */ + fileOutputStream = new FileOutputStream(file); + fileChannel = fileOutputStream.getChannel(); + + try { + fileLock = fileChannel.tryLock(); // exclusive lock + } catch (Exception e) { + throw new IOException("another thread is writing, refresh and try again"); + } + if (fileLock != null) { + fileChannel.write(ByteBuffer.wrap(str.getBytes())); + if (fileLock.isValid()) { + fileLock.release(); // release-write-lock + } + if (file.length() != str.getBytes().length) { + throw new IOException("write succeeded but the content was lost, re-edit and try again"); + } + } + + } catch (IOException e) { + logger.error(e.getMessage()); + throw new IOException(e.getMessage()); + } finally { + close(fileChannel); + close(fileOutputStream); + } + return true; + } + + public static void close(Closeable closeable) { + if (closeable != null) { + try { + closeable.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + + public static boolean createNewFile(String filePath) { + return createNewFile(new File(filePath)); + } + + public static boolean createNewFile(File file) { + try { + if (!file.exists()) { + if (!file.getParentFile().exists()) { + if (!file.getParentFile().mkdirs()) { + return false; + } + } + if (!file.createNewFile()) { + return false; + } + } + } catch (IOException e) { + logger.error("create file failed, path = {}", file.getAbsoluteFile()); + return false; + } + return true; + } + +}
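A hypothetical usage of the FileUtils helpers above (the path and JSON content are made up): createNewFile prepares missing parent directories first, and writeStr2ReplaceFileSync serializes writers through an exclusive FileLock, so a concurrent writer fails fast with an IOException instead of interleaving output.

import com.osx.core.utils.FileUtils;

public class RouteTableRefresh {
    public static void main(String[] args) throws Exception {
        String path = "/tmp/osx/route_table.json";  // illustrative path
        // Creates /tmp/osx first if needed, then the file itself.
        if (FileUtils.createNewFile(path)) {
            // Holds an exclusive lock for the duration of the write.
            FileUtils.writeStr2ReplaceFileSync("{\"route_table\":{}}", path);
        }
    }
}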
diff --git a/java/osx/core/src/main/java/com/osx/core/utils/FlowLogPrinter.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/FlowLogPrinter.java similarity index 81% rename from java/osx/core/src/main/java/com/osx/core/utils/FlowLogPrinter.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/FlowLogPrinter.java index 897700de63..fb5068d70d 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/FlowLogPrinter.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/utils/FlowLogPrinter.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.utils; - -import com.osx.core.context.Context; - -public interface FlowLogPrinter { - - public void print(Context context); -} +//package com.osx.core.utils; +// +// +// +//public interface FlowLogPrinter { +// +// public void print(Context context); +//} diff --git a/java/osx/osx-core/src/main/java/com/osx/core/utils/FlowLogUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/FlowLogUtil.java new file mode 100644 index 0000000000..be8cc25b5c --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/utils/FlowLogUtil.java @@ -0,0 +1,22 @@ +package com.osx.core.utils; + +import com.osx.api.context.Context; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class FlowLogUtil { + static Logger logger = LoggerFactory.getLogger("flow"); + + public static void printFlowLog(Context context) { + try { + logger.info(context.toString()); + } catch (Throwable ignore) { + } + + } + + + + + +} diff --git a/java/osx/core/src/main/java/com/osx/core/utils/GetSystemInfo.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/GetSystemInfo.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/utils/GetSystemInfo.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/GetSystemInfo.java diff --git a/java/osx/core/src/main/java/com/osx/core/utils/JsonUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/JsonUtil.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/utils/JsonUtil.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/JsonUtil.java index ef2d441cdc..ea183c7171 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/JsonUtil.java +++ b/java/osx/osx-core/src/main/java/com/osx/core/utils/JsonUtil.java @@ -58,8 +58,8 @@ public static T json2Object(String json, Class c) { T t = null; try { t = mapper.readValue(json, c); - } catch (IOException e) { - e.printStackTrace(); + } catch (IOException ignore) { + } return t; } @@ -82,8 +82,7 @@ public static T json2List(String json, TypeReference typeReference) { T result = null; try { result = mapper.readValue(json, typeReference); - } catch (IOException e) { - e.printStackTrace(); + } catch (IOException ignore) { } return result; } @@ -123,7 +122,7 @@ public static T object2Objcet(Object source, TypeReference tr) { } public static String formatJson(String jsonStr) { - return formatJson(jsonStr, " "); + return formatJson(jsonStr, "\t"); } /*** @@ -205,7 +204,7 @@ public static String pbToJson(MessageOrBuilder message) { public static void main(String[] args) { - String s = JsonUtil.formatJson("{\"route_table\":{\"default\":{\"default\":[{\"ip\":\"127.0.0.1\",\"port\":9999,\"useSSL\":false}]},\"10000\":{\"default\":[{\"ip\":\"127.0.0.1\",\"port\":8889}],\"serving\":[{\"ip\":\"127.0.0.1\",\"port\":8080}]},\"123\":[{\"host\":\"10.35.27.23\",\"port\":8888,\"useSSL\":false,\"negotiationType\":\"\",\"certChainFile\":\"\",\"privateKeyFile\":\"\",\"caFile\":\"\"}]},\"permission\":{\"default_allow\":true}}"); + String s = JsonUtil.formatJson("{\"route_table\":{\"default\":{\"default\":[{\"ip\":\"127.0.0.1\",\"port\":9999,\"useSSL\":false}]},\"10000\":{\"default\":[{\"ip\":\"127.0.0.1\",\"port\":8889}],\"serving\":[{\"ip\":\"127.0.0.1\",\"port\":8080}]},\"123\":[{\"host\":\"127.0.0.1\",\"port\":8888,\"useSSL\":false,\"negotiationType\":\"\",\"certChainFile\":\"\",\"privateKeyFile\":\"\",\"caFile\":\"\"}]},\"permission\":{\"default_allow\":true}}"); System.out.println(s); } diff --git 
a/java/osx/core/src/main/java/com/osx/core/utils/NetUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/NetUtils.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/utils/NetUtils.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/NetUtils.java diff --git a/java/osx/osx-core/src/main/java/com/osx/core/utils/OSXCertUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/OSXCertUtils.java new file mode 100644 index 0000000000..2123142f4d --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/utils/OSXCertUtils.java @@ -0,0 +1,147 @@ +package com.osx.core.utils; + + +import com.osx.core.config.MetaInfo; +import sun.misc.BASE64Decoder; +import sun.security.x509.X509CertImpl; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import java.io.*; +import java.security.*; +import java.security.cert.Certificate; +import java.security.cert.CertificateFactory; +import java.security.spec.PKCS8EncodedKeySpec; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + +/*** + * certificates type conversion + */ +public class OSXCertUtils { + private static final int I0 = 0; + private static final int I1 = 1; + private static final String type = "PKCS12"; + private static final AtomicInteger keyStoreCount = new AtomicInteger(1); + + /*** + * x509 certificate packaged into p12 certificate + * @param chain cert chain, issue cert +> superior cert +> ... + * @param privateKey issued cert private key + * @param filePath path to save p12 the cert + * @param alias alias + * @throws Exception NoCert, NoSuchAlgorithm , NoKeyStore, io + */ + public static void x509ToPkCS12(Certificate[] chain, Key privateKey, String filePath, String alias) throws Exception { + try (OutputStream os = new FileOutputStream(filePath)) { + KeyStore keyStore = KeyStore.getInstance(type); + keyStore.load(null); + keyStore.setKeyEntry(alias, privateKey, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray(), chain); + keyStore.store(os, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + } + } + + + /*** + * get x509 cert and private key by p12 cert + * @param filePath p12 cert file + * @param cs p1:storePassword p2:certPassword + * @return x509 certificate and private key + * @throws Exception NoCert, NoSuchAlgorithm , NoKeyStore, io + */ + public static X509AndKey getX509AndKeyByPkCS12(String filePath, String... 
cs) throws Exception { + try (InputStream is = new FileInputStream(filePath)) { + KeyStore keyStore = KeyStore.getInstance(type); + keyStore.load(is, toCharArray(I0, cs)); + String alias = keyStore.aliases().nextElement(); + return new X509AndKey(((X509CertImpl) keyStore.getCertificate(alias)), + ((PrivateKey) keyStore.getKey(alias, toCharArray(I1, cs)))); + } + } + + public static SSLContext getSSLContext(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + KeyStore keyStore = getKeyStore(caPath, clientCertPath, clientKeyPath); + // Initialize the ssl context object + SSLContext sslContext = SSLContext.getInstance("SSL"); + TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)}; + // Load client certificate + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom()); + return sslContext; + } + + public static KeyStore getKeyStore(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + KeyStore keyStore = KeyStore.getInstance("PKCS12"); + keyStore.load(null); + keyStore.setKeyEntry(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_ALIAS, importPrivateKey(clientKeyPath), MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray(), + new Certificate[]{importCert(clientCertPath), importCert(caPath)}); + return keyStore; + } + + public static KeyStore getTrustStore(String keyStorePath, String trustStoreType) throws Exception { + KeyStore keyStore = KeyStore.getInstance(trustStoreType.toUpperCase()); + keyStore.load(new FileInputStream(new File(keyStorePath)), MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + return keyStore; + } + + public static String createKeyStore(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + PrivateKey privateKey = importPrivateKey(clientKeyPath); +// Certificate[] certificates = {importCert(clientCertPath), importCert(caPath)}; + Certificate[] certificates = {importCert(clientCertPath), importCert(caPath)}; + String pfxPath = OSXCertUtils.getTempStorePath(); + File pfxFile = new File(pfxPath); + FileUtils.createNewFile(pfxFile); + OSXCertUtils.x509ToPkCS12(certificates, privateKey, pfxPath, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_ALIAS); + return pfxPath; + } + + public static Certificate importCert(String certFile) throws Exception { + try (FileInputStream certStream = new FileInputStream(certFile)) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + return cf.generateCertificate(certStream); + } + } + + // Import private key + public static PrivateKey importPrivateKey(String privateKeyFile) throws Exception { + try (FileInputStream keyStream = new FileInputStream(privateKeyFile)) { + String space = ""; + byte[] bytes = new byte[keyStream.available()]; + int length = keyStream.read(bytes); + String keyString = new String(bytes, 0, length); + if (keyString.startsWith("-----BEGIN PRIVATE KEY-----\n")) { + keyString = keyString.replace("-----BEGIN PRIVATE KEY-----\n", space).replace("-----END PRIVATE KEY-----", space); + } + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(new BASE64Decoder().decodeBuffer(keyString)); + return KeyFactory.getInstance("RSA").generatePrivate(keySpec); + } + } + + + //determine whether the string is null and get the default string character array + private static char[] toCharArray(int index, String... str) { + return str.length <= index || str[index] == null ? 
MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray() : str[index].toCharArray(); + } + + public static String getTempStorePath(){ + return ""; + } + + /*** + * this class pack X509Certificate and privateKey + */ + public static class X509AndKey { + private final X509CertImpl x509Cert; + private final PrivateKey privateKey; + + public X509AndKey(X509CertImpl x509Certificate, PrivateKey privateKey) { + this.x509Cert = x509Certificate; + this.privateKey = privateKey; + } + + } +} diff --git a/java/osx/osx-core/src/main/java/com/osx/core/utils/OsxX509TrustManager.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/OsxX509TrustManager.java new file mode 100644 index 0000000000..5a02cf0d43 --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/utils/OsxX509TrustManager.java @@ -0,0 +1,247 @@ +package com.osx.core.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.*; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +public class OsxX509TrustManager implements X509TrustManager { + private static final Logger logger = LoggerFactory.getLogger(OsxX509TrustManager.class); + public static final String tabs = "%2F", equalSign = "%3D"; + + private final X509TrustManager x509TrustManager; + + public OsxX509TrustManager(X509TrustManager x509TrustManager) { + this.x509TrustManager = x509TrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) { + try { + if (this.x509TrustManager == null) return; + this.x509TrustManager.checkClientTrusted(chain, authType); + } catch (CertificateException exc) { + logger.error(exc.getMessage()); + } + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) { + // sunJSSEX509TrustManager.checkServerTrusted(chain, authType); +// if (checkServer) { +// for (X509Certificate x509Certificate : chain) { +// // Use ca certificate verify +// verify(caX509Certificate, x509Certificate); +// +// // Send ocsp request verify +// ocspVerify(x509Certificate); +// } +// } + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + if (this.x509TrustManager == null) return null; + return this.x509TrustManager.getAcceptedIssuers(); + } + + public static OsxX509TrustManager getInstance() { + return new OsxX509TrustManager(null); + } + + public static OsxX509TrustManager getInstance(KeyStore keyStore) throws NoSuchProviderException, NoSuchAlgorithmException, KeyStoreException { + X509TrustManager x509TrustManager = null; + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509", "SunJSSE"); + tmf.init(keyStore); + TrustManager[] tms = tmf.getTrustManagers(); + for (TrustManager tm : tms) { + if (tm instanceof X509TrustManager) { + x509TrustManager = (X509TrustManager) tm; + break; + } + } + return new OsxX509TrustManager(x509TrustManager); + } + + // Verify that the certificate if expired, and is issued for the root certificate +// public static void verify(X509Certificate superiorCert, X509Certificate issueCert) throws CertificateException { +// try { +// issueCert.checkValidity(); +// issueCert.verify(superiorCert.getPublicKey()); +// } catch (Exception e) { +// throw new CertificateException(e); +// } +// } + + // Obtain ocsp service address from the certificate and verify the validity of the certificate +// public 
static void ocspVerify(X509Certificate x509Certificate) throws CertificateException { +// X509CertImpl x509Cert = (X509CertImpl) x509Certificate; +// AuthorityInfoAccessExtension accessExtension = x509Cert.getAuthorityInfoAccessExtension(); +// List accessDescriptions = accessExtension.getAccessDescriptions(); +// for (AccessDescription accessDescription : accessDescriptions) { +// String anObject = accessDescription.getAccessMethod().toString(); +// if ("ocsp".equals(anObject) || "1.3.6.1.5.5.7.48.1".equals(anObject)) { +// GeneralName accessLocation = accessDescription.getAccessLocation(); +// URI ocspUrl = ((URIName) accessLocation.getName()).getURI(); +// goSendOCSP(ocspUrl.toString(), x509Cert); +// } +// } +// } + + // Send ocsp request +// public static void goSendOCSP(String ocspUrl, X509CertImpl x509Certificate) throws CertificateException { +// try { +// URL url = new URL(ocspUrl + "/" + getOcspRequestData(x509Certificate)); +// HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection(); +// urlConnection.setConnectTimeout(5000); +// urlConnection.setReadTimeout(15000); +// urlConnection.setRequestMethod("GET"); +// urlConnection.setDoOutput(true); +// urlConnection.setDoInput(true); +// urlConnection.setRequestProperty("Content-type", "application/ocsp-request"); +// +// try (InputStream br = urlConnection.getInputStream(); +// ByteArrayOutputStream aos = new ByteArrayOutputStream() +// ) { +// int len; +// byte[] bytes = new byte[br.available()]; +// while ((len = br.read(bytes)) != -1) { +// aos.write(bytes, 0, len); +// } +// OCSPResponse ocspResponse = new OCSPResponse(aos.toByteArray()); +// OCSPResponse.ResponseStatus responseStatus = ocspResponse.getResponseStatus(); +// +// if (!responseStatus.equals(OCSPResponse.ResponseStatus.SUCCESSFUL)) { +// throw new CertificateException("ocsp request failed, request state: " + responseStatus); +// } +// +// Set certIds = ocspResponse.getCertIds(); +// for (CertId certId : certIds) { +// // Date nextUpdate = singleResponse.getNextUpdate(); +// // CRLReason revocationReason = singleResponse.getRevocationReason(); +// // Date thisUpdate = singleResponse.getThisUpdate(); +// OCSPResponse.SingleResponse singleResponse = ocspResponse.getSingleResponse(certId); +// OCSP.RevocationStatus.CertStatus certStatus = singleResponse.getCertStatus(); +// System.out.println("server certificate serial number: " + certId.getSerialNumber().toString(16) + ", status: " + certStatus); +// +// if (!OCSP.RevocationStatus.CertStatus.GOOD.equals(certStatus)) { +// // throw new CertificateException("服务器验证失败, 证书状态: " + certStatus); +// } +// } +// +// +// } catch (Exception e) { +// throw new CertificateException(e); +// } +// } catch (IOException e) { +// throw new CertificateException(e); +// } +// } + + // get ocsp request bytes +// private static byte[] getOcspRequestBytesData(X509CertImpl x509Certificate) throws IOException { +// return new OCSPRequest(new CertId(x509Certificate, x509Certificate.getSerialNumberObject())).encodeBytes(); +// } + + // get Base64 encode ocsp request url string parameter +// private static String getOcspRequestData(X509CertImpl certificate) throws IOException { +// CertId certId = new CertId(certificate, certificate.getSerialNumberObject()); +// OCSPRequest request = new OCSPRequest(certId); +// String encodeBuffer = new BASE64Encoder().encodeBuffer(request.encodeBytes()); +// return encodeBuffer.replace("\r\n", "").replace("/", tabs).replace("=", equalSign); +// } + + // OCSPRequest +// private static class 
OCSPRequest { +// private static final Debug debug = Debug.getInstance("certpath"); +// private static final boolean dump; +// private final List certIds; +// private final List extensions; +// private byte[] nonce; +// +// public OCSPRequest(CertId certId) { +// this(Collections.singletonList(certId)); +// } +// +// public OCSPRequest(List certIdList) { +// this.certIds = certIdList; +// this.extensions = Collections.emptyList(); +// } +// +// public OCSPRequest(List certIdList, List extensionList) { +// this.certIds = certIdList; +// this.extensions = extensionList; +// } +// +// public byte[] encodeBytes() throws IOException { +// DerOutputStream fillDOS = new DerOutputStream(); +// DerOutputStream certIdDOS = new DerOutputStream(); +// +// for (CertId certId : this.certIds) { +// DerOutputStream encodeDos = new DerOutputStream(); +// certId.encode(encodeDos); +// certIdDOS.write((byte) 48, encodeDos); +// } +// +// fillDOS.write((byte) 48, certIdDOS); +// DerOutputStream extensionDos; +// DerOutputStream endDos; +// if (!this.extensions.isEmpty()) { +// extensionDos = new DerOutputStream(); +// +// for (java.security.cert.Extension extension : this.extensions) { +// extension.encode(extensionDos); +// if (extension.getId().equals(PKIXExtensions.OCSPNonce_Id.toString())) { +// this.nonce = extension.getValue(); +// } +// } +// +// endDos = new DerOutputStream(); +// endDos.write((byte) 48, extensionDos); +// fillDOS.write(DerValue.createTag((byte) -128, true, (byte) 2), endDos); +// } +// +// extensionDos = new DerOutputStream(); +// extensionDos.write((byte) 48, fillDOS); +// endDos = new DerOutputStream(); +// endDos.write((byte) 48, extensionDos); +// byte[] bytes = endDos.toByteArray(); +// if (dump) { +// HexDumpEncoder dumpEncoder = new HexDumpEncoder(); +// debug.println("OCSPRequest bytes...\n\n" + dumpEncoder.encode(bytes) + "\n"); +// } +// +// return bytes; +// } +// +// public List getCertIds() { +// return this.certIds; +// } +// +// public byte[] getNonce() { +// return this.nonce; +// } +// +// static { +// dump = debug != null && Debug.isOn("ocsp"); +// } +// } + + public static class HostnameVerifier2 implements HostnameVerifier { + + @Override + public boolean verify(String s, SSLSession sslSession) { + return true; + } + + public static HostnameVerifier2 getInstance() { + return new HostnameVerifier2(); + } + } +}
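A sketch of wiring the new OsxX509TrustManager into an SSLContext, assuming a PKCS12 trust store (the class name TlsContextDemo and the path argument are illustrative). Note that checkServerTrusted above is currently a no-op, so this configuration accepts server certificates without verification.

import com.osx.core.utils.OSXCertUtils;
import com.osx.core.utils.OsxX509TrustManager;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import java.security.KeyStore;
import java.security.SecureRandom;

public class TlsContextDemo {
    public static SSLContext build(String trustStorePath) throws Exception {
        // Load the broker trust store (PKCS12 here; the type is caller-chosen).
        KeyStore trustStore = OSXCertUtils.getTrustStore(trustStorePath, "PKCS12");
        // Wrap the first platform X509TrustManager found via SunX509/SunJSSE.
        TrustManager[] tms = {OsxX509TrustManager.getInstance(trustStore)};
        SSLContext context = SSLContext.getInstance("TLS");
        context.init(null, tms, new SecureRandom());
        return context;
    }
}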
diff --git a/java/osx/osx-core/src/main/java/com/osx/core/utils/PropertiesUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/PropertiesUtil.java new file mode 100644 index 0000000000..741fb0876b --- /dev/null +++ b/java/osx/osx-core/src/main/java/com/osx/core/utils/PropertiesUtil.java @@ -0,0 +1,131 @@ +package com.osx.core.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.util.Properties; + + + +public final class PropertiesUtil +{ + public static Logger logger = LoggerFactory.getLogger(PropertiesUtil.class); + + public static Properties getProperties(String path) + { + Properties prop = new Properties(); + + loadProp(prop, path); + + return prop; + } + + private static void loadProp(Properties p, String conf) + { + InputStream is = getInputStream(conf); + + if(null != is) + { + try + { + p.load(is); + } + catch (IOException e) + { + logger.info("file not found!"); + } + finally + { + if(is != null) + { + try + { + is.close(); + } + catch (IOException e) + { + logger.info("stream close fail!"); + } + } + } + } + } + + // get an input stream for the given path + private static InputStream getInputStream(String conf) + { + File file = new File(conf); + InputStream is = null; + try { + is = new BufferedInputStream(new FileInputStream(file)); + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + return is; + } + + // get an output stream for the given classpath resource + private static OutputStream getOutPutStream(String conf) + { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + + OutputStream out = null; + + if(null != classLoader) + { + String filePath = classLoader.getResource(conf).getFile(); + try + { + out = new FileOutputStream(filePath); + } + catch (FileNotFoundException e) + { + logger.info("file not found!!!"); + } + } + return out; + } + + // read a value by key + public static String getValue(Properties p, String key) + { + String value = p.getProperty(key); + + return value == null ? "" : value; + } + + // set key=value and persist + public static void setValue(String conf, String key, String value) + { + Properties p = getProperties(conf); + + OutputStream out = getOutPutStream(conf); + + p.setProperty(key, value); + + try + { + p.store(out, "set: " + key + "=" + value); + } + catch (IOException e) + { + logger.info("set properties fail!!!"); + } + finally + { + if(out != null) + { + try + { + out.close(); + } + catch (IOException e) + { + logger.info("stream close fail!"); + } + } + } + } + +} \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/utils/RouterUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/RouterUtil.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/utils/RouterUtil.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/RouterUtil.java diff --git a/java/osx/core/src/main/java/com/osx/core/utils/ServerUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/ServerUtil.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/utils/ServerUtil.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/ServerUtil.java diff --git a/java/osx/core/src/main/java/com/osx/core/utils/ThreadPoolUtil.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/ThreadPoolUtil.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/utils/ThreadPoolUtil.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/ThreadPoolUtil.java diff --git a/java/osx/core/src/main/java/com/osx/core/utils/ToStringUtils.java b/java/osx/osx-core/src/main/java/com/osx/core/utils/ToStringUtils.java similarity index 100% rename from java/osx/core/src/main/java/com/osx/core/utils/ToStringUtils.java rename to java/osx/osx-core/src/main/java/com/osx/core/utils/ToStringUtils.java diff --git a/java/osx/pom.xml b/java/osx/pom.xml index 6ca280af42..9b1794af46 100644 --- a/java/osx/pom.xml +++ b/java/osx/pom.xml @@ -9,16 +9,9 @@ pom ${osx.version} - core - broker - - - - - - - - + osx-core + osx-broker + osx-api @@ -30,7 +23,7 @@ 1.51.1 1.18.24 3.21.12 - 5.26 + 0.6.1 1.6.1 1.7.36 @@ -57,15 +50,19 @@ 1.10.0 4.13.2 5.12.1 + 3.8.0 + - - - - - + + + com.lmax + disruptor + ${disruptor.version} + + org.slf4j slf4j-api @@ -121,6 +118,12 @@ ${jetty.version} + + + + + + commons-io commons-io @@ -209,6 +212,12 @@ ${flatbuffers-java.version} + + commons-net + commons-net + ${commons-net.version} + + com.google.flatbuffers flatbuffers-java-grpc @@ -224,19 +233,16 @@ grpc-core ${grpc.version} - io.grpc grpc-netty-shaded ${grpc.version} - io.grpc grpc-protobuf ${grpc.version} - io.grpc grpc-stub ${grpc.version} @@ -247,54 +253,41 @@ protobuf-java-util ${protobuf.version} - - - com.fasterxml.jackson.core jackson-annotations ${jackson.version} - 
com.fasterxml.jackson.core jackson-databind ${jackson.version} - io.grpc grpc-core ${grpc.version} - io.grpc grpc-netty-shaded ${grpc.version} - io.grpc grpc-protobuf ${grpc.version} - - io.grpc grpc-stub ${grpc.version} - - com.googlecode.protobuf-java-format protobuf-java-format ${protobuf-java-format.version} - - org.apache.httpcomponents httpclient @@ -384,9 +377,6 @@ - - - org.apache.maven.plugins diff --git a/java/osx/proto/osx.proto b/java/osx/proto/osx.proto index f75453ada1..b8aca166e1 100644 --- a/java/osx/proto/osx.proto +++ b/java/osx/proto/osx.proto @@ -51,9 +51,16 @@ enum Metadata { SourceComponentName = 2; // source component name TargetComponentName = 3; // target component name TargetMethod = 4; // target method - MessageOffSet = 5; // message sequence number - InstanceId = 6; // instance ID - Timestamp = 7; // timestamp + SourceMethod = 5; // protocol flag + MessageOffSet = 6; // message sequence number + InstanceId = 7; // instance ID + Timestamp = 8; // timestamp + MessageFlag = 9; // message flag + MessageTopicBack = 10; // queue that receives reply messages + RetryCount = 11; // retry count + Timeout = 12; // timeout + JobId = 13; // jobId + } // transport-layer inbound message encoding @@ -70,6 +77,33 @@ message Outbound { string message = 4; // status description } +message PeekInbound { + string topic = 1; // optional session topic, unique within a channel, used to isolate transfers on the same channel +} + +message PopInbound { + string topic = 1; // optional session topic, unique within a channel, used to isolate transfers on the same channel + int32 timeout = 2; // optional blocking timeout, defaults to 120s +} + +message PushInbound { + bytes payload = 1; // binary payload + string topic = 2; // optional session topic, unique within a channel, used to isolate transfers on the same channel + map<string, string> metadata = 3; // optional reserved parameters for extensibility +} + +message ReleaseInbound { + string topic = 1; // optional session topic, unique within a channel, used to isolate transfers on the same channel + int32 timeout = 2; // optional blocking timeout +} + +message TransportOutbound { + bytes payload = 1; // binary payload + string code = 2; // status code + string message = 3; // status description +} + + // If interconnection uses the asynchronous transport protocol as its reference standard, the Header reuses metadata to carry the interconnection protocol header, and metadata also carries the message attributes of the asynchronous scenario // If interconnection uses another protocol as its reference standard, the Header reuses metadata to carry the interconnection protocol header // If interconnection uses gRPC as its reference standard, the Header reuses the HTTP/2 headers to carry the interconnection protocol header @@ -77,6 +111,12 @@ message Outbound { service PrivateTransferProtocol { rpc transport (stream Inbound) returns (stream Outbound); rpc invoke (Inbound) returns (Outbound); + rpc test (stream Inbound) returns (stream Outbound); + + rpc peek (PeekInbound) returns (TransportOutbound); + rpc pop (PopInbound) returns (TransportOutbound); + rpc push (PushInbound) returns (TransportOutbound); + rpc release (ReleaseInbound) returns (TransportOutbound); } @@ -85,3 +125,7 @@ service PrivateTransferProtocol { + + + + diff --git a/python/fate/arch/federation/osx/_mq_channel.py b/python/fate/arch/federation/osx/_mq_channel.py index cacef6d8b3..30f229c6cd 100644 --- a/python/fate/arch/federation/osx/_mq_channel.py +++ b/python/fate/arch/federation/osx/_mq_channel.py @@ -35,6 +35,7 @@ def __init__( self._namespace = namespace self._send_topic = send_topic self._receive_topic = receive_topic + self._index = 1 self._src_party_id = src_party_id self._src_role = src_role self._dst_party_id = dst_party_id @@ -51,7 +52,7 @@ def consume(self, offset=-1): LOGGER.debug(f"consume, offset={offset}, mq={self}") self._get_or_create_channel() meta = dict( - MessageTopic=self._send_topic, + MessageTopic=self._receive_topic, TechProviderCode="FATE", SourceNodeID=self._src_party_id, TargetNodeID=self._dst_party_id, @@ -64,9 +65,8 @@ inbound = osx_pb2.Inbound(metadata=meta) LOGGER.debug(f"consume, inbound={inbound}, mq={self}") result = self._stub.invoke(inbound) - LOGGER.debug(f"consume, result={result}, mq={self}") - print(result) - print(result.code) + LOGGER.debug(f"consume, result={result.code}, mq={self}") + return result @nretry
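These hunks swap MessageTopic so that consume reads from _receive_topic while produce (below) writes to _send_topic, and produce now raises on a non-zero result code so the @nretry decorator can retry the call. For reference, an equivalent Inbound could be assembled on the Java side roughly as follows; the helper class and the "PUSH" target-method value are illustrative assumptions, and only the metadata keys mirror the ones used by this Python client.

import com.google.protobuf.ByteString;
import org.ppc.ptp.Osx;

public class InboundFactory {
    public static Osx.Inbound newProduceInbound(String topic, String srcParty,
                                                String dstParty, byte[] payload) {
        return Osx.Inbound.newBuilder()
                .putMetadata("MessageTopic", topic)          // the producer's send topic
                .putMetadata("TechProviderCode", "FATE")
                .putMetadata("SourceNodeID", srcParty)
                .putMetadata("TargetNodeID", dstParty)
                .putMetadata("TargetMethod", "PUSH")         // illustrative TargetMethod value
                .setPayload(ByteString.copyFrom(payload))
                .build();
    }
}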
mq={self}") + return result @nretry @@ -94,7 +94,7 @@ def produce(self, body, properties): LOGGER.debug(f"produce body={body}, properties={properties}, mq={self}") self._get_or_create_channel() meta = dict( - MessageTopic=self._receive_topic, + MessageTopic=self._send_topic, TechProviderCode="FATE", SourceNodeID=self._src_party_id, TargetNodeID=self._dst_party_id, @@ -107,7 +107,11 @@ def produce(self, body, properties): inbound = osx_pb2.Inbound(metadata=meta, payload=msg.SerializeToString()) LOGGER.debug(f"produce inbound={inbound}, mq={self}") result = self._stub.invoke(inbound) - LOGGER.debug(f"produce result={result}, mq={self}") + + LOGGER.debug(f"produce {self._receive_topic} index {self._index} result={result.code}, mq={self}") + if result.code!="0": + raise RuntimeError(f"produce msg error ,code : {result.code} msg : {result.message}") + self._index+=1 return result @nretry @@ -115,7 +119,7 @@ def ack(self, offset): LOGGER.debug(f"ack offset={offset}, mq={self}") self._get_or_create_channel() meta = dict( - MessageTopic=self._send_topic, + MessageTopic=self._receive_topic, TechProviderCode="FATE", SourceNodeID=self._src_party_id, TargetNodeID=self._dst_party_id,