From 1a213f1e7eafd26366f34fab4e05ca58faba6cb0 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 31 Jul 2023 14:54:08 +0800 Subject: [PATCH] rename union & add schema check(#4668) Signed-off-by: Yu Wu --- examples/pipeline/test_union.py | 20 +++--- python/fate/components/components/__init__.py | 6 +- .../components/{feature_union.py => union.py} | 6 +- python/fate/ml/preprocessing/__init__.py | 2 +- python/fate/ml/preprocessing/feature_union.py | 42 ----------- python/fate/ml/preprocessing/union.py | 70 +++++++++++++++++++ 6 files changed, 87 insertions(+), 59 deletions(-) rename python/fate/components/components/{feature_union.py => union.py} (92%) delete mode 100644 python/fate/ml/preprocessing/feature_union.py create mode 100644 python/fate/ml/preprocessing/union.py diff --git a/examples/pipeline/test_union.py b/examples/pipeline/test_union.py index ec104dff05..9680764e1b 100644 --- a/examples/pipeline/test_union.py +++ b/examples/pipeline/test_union.py @@ -13,23 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import FeatureUnion +from fate_client.pipeline.components.fate import Union from fate_client.pipeline.interface import DataWarehouseChannel pipeline = FateFlowPipeline().set_roles(guest="9999") -feature_union_0 = FeatureUnion("feature_union_0", - runtime_roles=["guest"], - input_data_list=[DataWarehouseChannel(name="breast_hetero_guest_sid", - namespace="experiment"), - DataWarehouseChannel(name="breast_hetero_guest_sid", - namespace="experiment")], - axis=0) +union_0 = Union("union_0", + runtime_roles=["guest"], + input_data_list=[DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid"), + DataWarehouseChannel(name="breast_hetero_guest", + namespace="experiment_sid")], + axis=0) -pipeline.add_task(feature_union_0) +pipeline.add_task(union_0) pipeline.compile() print(pipeline.get_dag()) pipeline.fit() -print(pipeline.get_task_info("feature_union_0").get_output_data()) +print(pipeline.get_task_info("union_0").get_output_data()) diff --git a/python/fate/components/components/__init__.py b/python/fate/components/components/__init__.py index b86a8e2798..2610b47b42 100644 --- a/python/fate/components/components/__init__.py +++ b/python/fate/components/components/__init__.py @@ -121,10 +121,10 @@ def hetero_feature_selection(self): return hetero_feature_selection @_lazy_cpn - def feature_union(self): - from .feature_union import feature_union + def union(self): + from .union import union - return feature_union + return union @_lazy_cpn def sample(self): diff --git a/python/fate/components/components/feature_union.py b/python/fate/components/components/union.py similarity index 92% rename from python/fate/components/components/feature_union.py rename to python/fate/components/components/union.py index 05fd263927..12f652d1f7 100644 --- a/python/fate/components/components/feature_union.py +++ b/python/fate/components/components/union.py @@ -18,7 +18,7 @@ 
@cpn.component(roles=[GUEST, HOST], provider="fate") -def feature_union( +def union( ctx: Context, role: Role, input_data_list: cpn.dataframe_inputs(roles=[GUEST, HOST]), @@ -26,13 +26,13 @@ def feature_union( desc="axis along which concatenation is performed, 0 for row-wise, 1 for column-wise"), output_data: cpn.dataframe_output(roles=[GUEST, HOST]) ): - from fate.ml.preprocessing import FeatureUnion + from fate.ml.preprocessing import Union data_list = [] for data in input_data_list: data = data.read() data_list.append(data) sub_ctx = ctx.sub_ctx("train") - union_obj = FeatureUnion(axis) + union_obj = Union(axis) output_df = union_obj.fit(sub_ctx, data_list) output_data.write(output_df) diff --git a/python/fate/ml/preprocessing/__init__.py b/python/fate/ml/preprocessing/__init__.py index c74e409eea..77cdb94f60 100644 --- a/python/fate/ml/preprocessing/__init__.py +++ b/python/fate/ml/preprocessing/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. from .feature_scale import FeatureScale -from .feature_union import FeatureUnion +from .union import Union diff --git a/python/fate/ml/preprocessing/feature_union.py b/python/fate/ml/preprocessing/feature_union.py deleted file mode 100644 index 1b429a25ff..0000000000 --- a/python/fate/ml/preprocessing/feature_union.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Copyright 2023 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging - -from fate.arch import Context -from fate.arch.dataframe import DataFrame -from ..abc.module import Module - -logger = logging.getLogger(__name__) - - -class FeatureUnion(Module): - def __init__(self, axis=0): - self.axis = axis - - def fit(self, ctx: Context, train_data_list): - if self.axis == 0: - result_data = DataFrame.vstack(train_data_list) - elif self.axis == 1: - col_set = set() - for data in train_data_list: - data_cols = set(data.schema.columns) - if col_set.intersection(data_cols): - raise ValueError(f"column name conflict: {col_set.intersection(data_cols)}. " - f"Please check input data") - col_set.update(data_cols) - result_data = DataFrame.hstack(train_data_list) - else: - raise ValueError(f"axis must be 0 or 1, but got {self.axis}") - return result_data diff --git a/python/fate/ml/preprocessing/union.py b/python/fate/ml/preprocessing/union.py new file mode 100644 index 0000000000..bd28e361ac --- /dev/null +++ b/python/fate/ml/preprocessing/union.py @@ -0,0 +1,70 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging
+
+from fate.arch import Context
+from fate.arch.dataframe import DataFrame
+from ..abc.module import Module
+
+logger = logging.getLogger(__name__)
+
+
+class Union(Module):
+    """Concatenate data frames row-wise (axis=0) or column-wise (axis=1)."""
+
+    def __init__(self, axis=0):
+        self.axis = axis
+
+    def fit(self, ctx: Context, train_data_list):
+        """Validate the input schemas and return the concatenated DataFrame.
+
+        Every data set must share the sample_id_name and match_id_name of the
+        first one. For axis=0 the column sets must also be identical; for
+        axis=1 the column names must be pairwise disjoint.
+
+        :param ctx: execution context; not used by this method.
+        :param train_data_list: data frames to combine, each exposing `schema`.
+        :return: result of DataFrame.vstack (axis=0) or DataFrame.hstack (axis=1).
+        :raises ValueError: on an unsupported axis or inconsistent schemas.
+        """
+        if self.axis not in (0, 1):
+            raise ValueError(f"axis must be 0 or 1, but got {self.axis}")
+        ref_schema, ref_cols, seen_cols = None, set(), set()
+        for idx, data in enumerate(train_data_list):
+            schema = data.schema
+            # Compare with the first schema by position, not by truthiness: a
+            # None/empty reference value must not silently disable the check.
+            if idx == 0:
+                ref_schema = schema
+            else:
+                if schema.sample_id_name != ref_schema.sample_id_name:
+                    raise ValueError("Data sets should all have the same sample_id_name for union.")
+                if schema.match_id_name != ref_schema.match_id_name:
+                    raise ValueError("Data sets should all have the same match_id_name for union.")
+            cols = set(schema.columns)
+            if self.axis == 0:
+                if idx == 0:
+                    ref_cols = cols
+                elif cols != ref_cols:
+                    raise ValueError("Data sets should all have the same columns for union on 0 axis.")
+            else:
+                conflict = seen_cols & cols
+                if conflict:
+                    raise ValueError(f"column name conflict: {conflict}. "
+                                     "Please check input data")
+                seen_cols |= cols
+        if self.axis == 0:
+            return DataFrame.vstack(train_data_list)
+        return DataFrame.hstack(train_data_list)