Skip to content

Commit

Permalink
rename & update predict tools
Browse files Browse the repository at this point in the history
Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Sep 8, 2023
1 parent 6d72e40 commit c6e3ad0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/fate/ml/utils/predict_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id: pd.DataFrame):
return df


def to_fate_df(ctx, sample_id_name, match_id_name, result_df: pd.DataFrame):
def to_dist_df(ctx, sample_id_name, match_id_name, result_df: pd.DataFrame):

if LABEL in result_df:
reader = PandasReader(
Expand All @@ -61,7 +61,7 @@ def compute_predict_details(dataframe: DataFrame, task_type: Literal['binary', '

if not isinstance(dataframe, DataFrame):
raise ValueError('dataframe must be a fate DataFrame, but got {}'.format(type(dataframe)))
if dataframe.schema.label_name is not None:
if dataframe.schema.label_name is not None and dataframe.schema.label_name != LABEL:
dataframe.rename(label_name=LABEL)
assert PREDICT_SCORE in dataframe.schema.columns, 'column {} is not found in input dataframe'.format(PREDICT_SCORE)

Expand Down Expand Up @@ -138,7 +138,7 @@ def array_to_predict_df(

df[sample_id_name] = sample_ids.flatten()
df[match_id_name] = match_ids.flatten()
fate_df = to_fate_df(ctx, sample_id_name, match_id_name, df)
fate_df = to_dist_df(ctx, sample_id_name, match_id_name, df)
predict_result = compute_predict_details(fate_df, task_type, classes, threshold)

return predict_result

0 comments on commit c6e3ad0

Please sign in to comment.