fix federated metrics computation with categorical features (#4660)
disable cumsum for event count (#4660)
update HE param

Signed-off-by: Yu Wu <yolandawu131@gmail.com>
nemirorox committed Sep 8, 2023
1 parent 01c3741 commit 21a7d82
Showing 3 changed files with 16 additions and 17 deletions.
7 changes: 4 additions & 3 deletions examples/benchmark_quality/lr/pipeline-lr-binary.py
@@ -16,12 +16,13 @@
 
 import argparse
 
+from fate_test.utils import parse_summary_result
+
 from fate_client.pipeline import FateFlowPipeline
 from fate_client.pipeline.components.fate import CoordinatedLR, PSI
 from fate_client.pipeline.components.fate import Evaluation
 from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
-from fate_test.utils import extract_data, parse_summary_result
 
 
 def main(config="../../config.yaml", param="./breast_config.yaml", namespace=""):
@@ -88,14 +89,14 @@ def main(config="../../config.yaml", param="./breast_config.yaml", namespace="")
     pipeline.compile()
     pipeline.fit()
 
-    lr_0_data = pipeline.get_task_info("lr_0").get_output_data()["train_output_data"]
+    """lr_0_data = pipeline.get_task_info("lr_0").get_output_data()["train_output_data"]
     lr_1_data = pipeline.get_task_info("lr_1").get_output_data()["test_output_data"]
     lr_0_score = extract_data(lr_0_data, "predict_result")
     lr_0_label = extract_data(lr_0_data, "y")
     lr_1_score = extract_data(lr_1_data, "predict_result")
     lr_1_label = extract_data(lr_1_data, "y")
     lr_0_score_label = extract_data(lr_0_data, "predict_result", keep_id=True)
-    lr_1_score_label = extract_data(lr_1_data, "predict_result", keep_id=True)
+    lr_1_score_label = extract_data(lr_1_data, "predict_result", keep_id=True)"""
 
     result_summary = parse_summary_result(pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"])
     print(f"result_summary: {result_summary}")
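Note on the hunk above: the per-sample score extraction is kept but stringified out, since the benchmark now reads its quality figures from the Evaluation component's summary alone. A minimal sketch of that flow, assuming parse_summary_result returns a flat metric dict (the "auc" key below is an assumption, not the component's documented output):

    # hypothetical illustration of the metric flow this file relies on
    result_summary = parse_summary_result(
        pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"]
    )
    print(result_summary.get("auc"))  # exact keys depend on the Evaluation config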
2 changes: 1 addition & 1 deletion python/fate/components/core/params/_he_param.py
@@ -20,7 +20,7 @@
 
 
 class HEParam(pydantic.BaseModel):
-    kind: string_choice(["paillier"])
+    kind: string_choice(["paillier", "ou", "mock"])
     key_length: int = 1024
 
 
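For context, string_choice constrains the pydantic field to the listed values, so this hunk widens the accepted homomorphic-encryption backends from Paillier only to Paillier, Okamoto-Uchiyama ("ou"), and a plaintext "mock" scheme for testing. A self-contained sketch of equivalent validation, using typing.Literal as a stand-in for string_choice (the stand-in and the default value are assumptions for illustration only):

    from typing import Literal

    import pydantic

    class HEParamSketch(pydantic.BaseModel):
        # mirrors HEParam after this commit; Literal stands in for string_choice
        kind: Literal["paillier", "ou", "mock"] = "paillier"
        key_length: int = 1024

    print(HEParamSketch(kind="ou", key_length=2048))  # accepted after this commit
    # HEParamSketch(kind="rsa") would raise pydantic.ValidationError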
24 changes: 11 additions & 13 deletions python/fate/ml/feature_binning/hetero_feature_binning.py
@@ -182,12 +182,12 @@ def compute_federated_metrics(self, ctx: Context, binned_data):
         coder = ctx.guest.get("coder")
         columns = binned_data.schema.columns.to_list()
         # logger.info(f"self.bin_col: {self.bin_col}")
-        anonymous_col_bin = [binned_data.schema.anonymous_columns[columns.index(col)] for col in self.bin_col]
+        to_compute_col = self.bin_col + self.category_col
+        anonymous_col_bin = [binned_data.schema.anonymous_columns[columns.index(col)] for col in to_compute_col]
 
         ctx.guest.put("anonymous_col_bin", anonymous_col_bin)
         encrypt_y = ctx.guest.get("enc_y")
         # event count:
-        to_compute_col = self.bin_col + self.category_col
         feature_bin_sizes = [self._bin_obj._bin_count_dict[col] for col in self.bin_col]
         if self.category_col:
             for col in self.category_col:
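The fix in this hunk: anonymous column names are now looked up for the binned and the categorical columns together, so to_compute_col must be assembled before the lookup. A toy sketch of the indexing (all column names below are made up):

    columns = ["x0", "x1", "x2"]                              # binned_data.schema.columns
    anonymous_columns = ["host_0_0", "host_0_1", "host_0_2"]  # hypothetical anonymous names
    bin_col, category_col = ["x0"], ["x2"]

    to_compute_col = bin_col + category_col
    anonymous_col_bin = [anonymous_columns[columns.index(col)] for col in to_compute_col]
    assert anonymous_col_bin == ["host_0_0", "host_0_2"]  # categorical column now covered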
@@ -200,19 +200,20 @@ def compute_federated_metrics(self, ctx: Context, binned_data):
         hist_targets = binned_data.create_frame()
         hist_targets["event_count"] = encrypt_y
         hist_targets["non_event_count"] = 1
-        hist_schema = {"event_count": {"type": "paillier",
+        hist_schema = {"event_count": {"type": "ciphertext",
                                        "stride": 1,
                                        "pk": pk,
                                        "evaluator": evaluator,
                                        "coder": coder
                                        },
-                       "non_event_count": {"type": "tensor",
+                       "non_event_count": {"type": "plaintext",
                                            "stride": 1,
                                            "dtype": torch.int32}
                        }
         hist = HistogramBuilder(num_node=1,
                                 feature_bin_sizes=feature_bin_sizes,
-                                value_schemas=hist_schema)
+                                value_schemas=hist_schema,
+                                enable_cumsum=False)
         event_non_event_count_hist = to_compute_data.distributed_hist_stat(histogram_builder=hist,
                                                                            targets=hist_targets)
         event_non_event_count_hist.i_sub_on_key("non_event_count", "event_count")
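The bookkeeping in this hunk: every sample contributes its (encrypted) label to "event_count" and a constant 1 to "non_event_count", so after aggregation the latter holds per-bin totals, and i_sub_on_key subtracts events to leave true non-event counts. A toy plaintext illustration with made-up data, using plain pandas in place of the federated histogram:

    import pandas as pd

    df = pd.DataFrame({"bin": [0, 0, 1, 1, 1], "y": [1, 0, 0, 1, 1]})
    hist = df.groupby("bin").agg(event_count=("y", "sum"), non_event_count=("y", "size"))
    # mirrors i_sub_on_key("non_event_count", "event_count"): totals minus events
    hist["non_event_count"] -= hist["event_count"]
    print(hist)  # bin 0 -> 1 event / 1 non-event; bin 1 -> 2 events / 1 non-event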
@@ -293,7 +294,7 @@ def fit(self, ctx: Context, train_data, validate_data=None, skip_none=False):
 
         if self.method == "quantile":
             q = list(np.arange(0, 1, 1 / self.n_bins)) + [1.0]
-            split_pt_df = select_data.quantile(q=q, relative_error=self.relative_error)
+            split_pt_df = select_data.quantile(q=q, relative_error=self.relative_error).drop(0)
         elif self.method == "bucket":
             split_pt_df = select_data.qcut(q=self.n_bins)
         elif self.method == "manual":
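The .drop(0) here removes the q=0.0 row, i.e. the column minimum, which is a sample value rather than a usable upper bin boundary. A toy pandas sketch of the same idea (the FATE DataFrame's quantile/drop semantics are assumed analogous):

    import pandas as pd

    s = pd.Series([1.0, 2.0, 3.0, 4.0])
    q = [0.0, 0.25, 0.5, 0.75, 1.0]  # list(np.arange(0, 1, 1 / n_bins)) + [1.0] for n_bins=4
    split_pts = s.quantile(q).reset_index(drop=True).drop(0)  # drop the q=0.0 (minimum) row
    print(split_pts.tolist())  # [1.75, 2.5, 3.25, 4.0]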
@@ -311,7 +312,6 @@ def __get_col_bin_count(col):
         self._bin_count_dict = bin_count.to_dict()
 
     def bucketize_data(self, train_data):
-        # logger.debug(f"split pt dict: {self._split_pt_dict}")
         binned_df = train_data.bucketize(boundaries=self._split_pt_dict)
         return binned_df
 
@@ -328,11 +328,9 @@ def compute_all_col_metrics(self, event_non_event_count_hist, columns):
             col_event_count = pd.Series(
                 {bin_num: int(bin_count.data) for bin_num, bin_count in event_count_dict[col_name].items()}
             )
-            col_event_count = col_event_count - col_event_count.shift(1).fillna(0)
             col_non_event_count = pd.Series(
                 {bin_num: int(bin_count.data) for bin_num, bin_count in non_event_count_dict[col_name].items()}
             )
-            col_non_event_count = col_non_event_count - col_non_event_count.shift(1).fillna(0)
             if total_event_count is None:
                 total_event_count = col_event_count.sum() or 1
                 total_non_event_count = col_non_event_count.sum() or 1
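The removed shift(1) differencing existed to undo cumulative histogram output: with enable_cumsum=False (set in the hunks above and below), the builder now returns raw per-bin counts, so differencing them again would corrupt the metrics. A toy illustration of the retired step (made-up numbers):

    import pandas as pd

    cumulative = pd.Series({0: 3, 1: 5, 2: 9})            # what a cumsum histogram yields
    per_bin = cumulative - cumulative.shift(1).fillna(0)  # the differencing this commit removes
    print(per_bin.astype(int).tolist())                   # [3, 2, 4], the raw per-bin counts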
@@ -377,17 +375,17 @@ def compute_metrics(self, binned_data):
         hist_targets = binned_data.create_frame()
         hist_targets["event_count"] = binned_data.label
         hist_targets["non_event_count"] = 1
-        hist_schema = {"event_count": {"type": "tensor",
+        hist_schema = {"event_count": {"type": "plaintext",
                                        "stride": 1,
                                        "dtype": torch.int32},
-                       "non_event_count": {"type": "tensor",
+                       "non_event_count": {"type": "plaintext",
                                            "stride": 1,
                                            "dtype": torch.int32}
                        }
         hist = HistogramBuilder(num_node=1,
                                 feature_bin_sizes=feature_bin_sizes,
-                                value_schemas=hist_schema)
-        df = to_compute_data.as_pd_df()
+                                value_schemas=hist_schema,
+                                enable_cumsum=False)
         event_non_event_count_hist = to_compute_data.distributed_hist_stat(histogram_builder=hist,
                                                                            targets=hist_targets)
         event_non_event_count_hist.i_sub_on_key("non_event_count", "event_count")
