Skip to content

Commit

Permalink
edit schema check (#4668)
Browse files: browse the repository at this point in the history
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
nemirorox authored and sagewe committed Aug 9, 2023
1 parent a385717 commit f3bf22b
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions python/fate/ml/preprocessing/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ def __init__(self, axis=0):
self.axis = axis

def fit(self, ctx: Context, train_data_list):
label_name_list = [data.schema.label_name for data in train_data_list]
if sum([name != label_name_list[0] for name in label_name_list]):
raise ValueError(f"Data sets should all have the same label_name for union.")

sample_id_name_list = [data.schema.sample_id_name for data in train_data_list]
if sum([name != sample_id_name_list[0] for name in sample_id_name_list]):
raise ValueError(f"Data sets should all have the same sample_id_name for union.")
Expand All @@ -39,13 +35,13 @@ def fit(self, ctx: Context, train_data_list):
raise ValueError(f"Data sets should all have the same match_id_name for union.")

if self.axis == 0:
data_cols = set()
for data in train_data_list:
if data_cols:
if set(data.schema.columns) != data_cols:
raise ValueError(f"Data sets should all have the same columns for union on 0 axis.")
else:
data_cols = set(data.schema.columns)
label_name_list = [data.schema.label_name for data in train_data_list]
if sum([name != label_name_list[0] for name in label_name_list]):
raise ValueError(f"Data sets should all have the same label_name for union.")

column_name_list = [set(data.schema.columns) for data in train_data_list]
if sum([col_names != column_name_list[0] for col_names in column_name_list]):
raise ValueError(f"Data sets should all have the same columns for union on 0 axis.")
result_data = DataFrame.vstack(train_data_list)
elif self.axis == 1:
if sum([data.label is not None for data in train_data_list]) > 1:
Expand Down

0 comments on commit f3bf22b

Please sign in to comment.