Skip to content

Commit

Permalink
edit selection (#4661)
Browse files (browse the repository at this point in the history)
Signed-off-by: Yu Wu <yolandawu131@gmail.com>
  • Loading branch information
nemirorox committed Jul 11, 2023
1 parent cef9b57 commit 5a2c72a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
5 changes: 5 additions & 0 deletions python/fate/components/components/hetero_feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def train(
if manual_param.keep_col is not None:
keep_col = [columns[anonymous_columns.index(col)] for col in manual_param.keep_col]
manual_param.keep_col = keep_col
# temp code start
iv_param = iv_param.dict()
statistic_param = statistic_param.dict()
manual_param = manual_param.dict()
# temp code end
input_models = [model.read() for model in input_models]
if role.is_guest:
selection = HeteroSelectionModuleGuest(method, select_col, input_models,
Expand Down
17 changes: 9 additions & 8 deletions python/fate/ml/feature_selection/hetero_feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

logger = logging.getLogger(__name__)

DEFAULT_METRIC = {"iv": ["iv"], "statistic": ["mean"]}
DEFAULT_METRIC = {"iv": ["iv"], "statistics": ["mean"]}


class HeteroSelectionModuleGuest(HeteroModule):
Expand Down Expand Up @@ -186,7 +186,7 @@ def fit(self, ctx: Context, train_data, validate_data=None) -> None:
param=self.iv_param,
model=model,
keep_one=self.keep_one)
elif filter_type == "statistic":
elif filter_type == "statistics":
model = self.isometric_model_dict.get("statistics", None)
if model is None:
raise ValueError(f"Cannot find statistics model in input, please check")
Expand Down Expand Up @@ -400,13 +400,13 @@ def fit(self, ctx: Context, train_data, validate_data=None):
"""metric_names = self.param.metrics or []"""
# temp code ends
# local only
if self.method in ["statistic"]:
if self.method in ["statistics"]:
for metric_name in metric_names:
if metric_name not in self.model.get("metrics", {}):
if metric_name not in self.model.get("meta", {}).get("metrics", {}):
raise ValueError(f"metric {metric_name} not found in given statistic model with metrics: "
f"{metric_names}, please check")

metrics_all = pd.DataFrame(self.model.get("metrics_summary", {})).loc[metric_names]
f"{self.model.get('metrics', {})}, please check")
model_data = self.model.get("data", {})
metrics_all = pd.DataFrame(model_data.get("metrics_summary", {})).loc[metric_names]
self._all_metrics = metrics_all
missing_col = set(self._prev_selected_mask[self._prev_selected_mask].index). \
difference(set(metrics_all.columns))
Expand All @@ -431,7 +431,8 @@ def fit(self, ctx: Context, train_data, validate_data=None):
# host does not perform local iv selection
if ctx.local[0] == "host":
return
iv_metrics = pd.Series(self.model["metrics_summary"]["iv"])
model_data = self.model.get("data", {})
iv_metrics = pd.Series(model_data["metrics_summary"]["iv"])
metrics_all = pd.DataFrame(iv_metrics).T.rename({0: "iv"}, axis=0)
self._all_metrics = metrics_all
# works for multiple iv filters
Expand Down

0 comments on commit 5a2c72a

Please sign in to comment.