diff --git a/python/fate/arch/histogram/histogram.py b/python/fate/arch/histogram/histogram.py index 1a6da1f89a..107ff0c69d 100644 --- a/python/fate/arch/histogram/histogram.py +++ b/python/fate/arch/histogram/histogram.py @@ -467,7 +467,7 @@ def i_update(self, data, k=None) -> "ShuffledHistogram": ShuffledHistogram, the shuffled(if seed is not None) histogram """ if k is None: - k = data.partitions + k = data.partitions ** 2 mapper = get_partition_hist_build_mapper( self._node_size, self._feature_bin_sizes, self._value_schemas, self._seed, k ) diff --git a/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py b/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py index ca37ec13e0..053b120cb8 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py +++ b/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _get_sample_on_local_nodes, _update_sample_pos +from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _update_sample_pos_on_local_nodes, _merge_sample_pos from fate.ml.ensemble.learner.decision_tree.tree_core.hist import SBTHistogramBuilder from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import FedSBTSplitter from fate.ml.ensemble.learner.decision_tree.tree_core.loss import get_task_info @@ -108,47 +108,21 @@ def _update_sample_pos(self, ctx, cur_layer_nodes: List[Node], sample_pos: DataF sitename = ctx.local.party[0] + '_' + ctx.local.party[1] data_with_pos = DataFrame.hstack([data, sample_pos]) - map_func = functools.partial(_get_sample_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename) - local_sample_idx = data_with_pos.apply_row(map_func) - # local_samples = data_with_pos[local_sample_idx.as_tensor()] - local_samples = data_with_pos.iloc(local_sample_idx) - logger.info('{}/{} samples on local nodes'.format(len(local_samples), len(data))) - if len(local_samples) == 0: - updated_sample_pos = None - else: - """ - updated_sample_pos = sample_pos.loc(local_samples.get_indexer(target="sample_id"), preserve_order=True).create_frame() - update_func = functools.partial(_update_sample_pos, cur_layer_node=cur_layer_nodes, node_map=node_map) - map_rs = local_samples.apply_row(update_func) - updated_sample_pos["node_idx"] = map_rs # local_samples.apply_row(update_func) - """ - update_func = functools.partial(_update_sample_pos, cur_layer_node=cur_layer_nodes, node_map=node_map) - updated_sample_pos = local_samples.create_frame() - updated_sample_pos["node_idx"] = local_samples.apply_row(update_func) + map_func = functools.partial(_update_sample_pos_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename) + updated_sample_pos = data_with_pos.apply_row(map_func, columns=["g_on_local", "g_node_idx"]) # synchronize sample pos host_update_sample_pos = ctx.hosts.get('updated_data') - new_sample_pos = sample_pos.empty_frame() + merge_func = functools.partial(_merge_sample_pos) for host_data in host_update_sample_pos: - if host_data[0]: # True - pos_data, pos_index = host_data[1] - tmp_frame = sample_pos.create_frame() - tmp_frame = tmp_frame.loc(pos_index, preserve_order=True) - tmp_frame['node_idx'] = pos_data - new_sample_pos = DataFrame.vstack([new_sample_pos, tmp_frame]) - - if updated_sample_pos is not None: - if len(updated_sample_pos) == len(data): # all samples are on local - new_sample_pos = updated_sample_pos - else: - logger.info('stack new sample pos, guest len {}, host len {}'.format(len(updated_sample_pos), len(new_sample_pos))) - new_sample_pos = DataFrame.vstack([updated_sample_pos, new_sample_pos]) - else: - new_sample_pos = new_sample_pos # all samples are on host + updated_sample_pos = DataFrame.hstack([updated_sample_pos, host_data]).apply_row( + merge_func, + columns=["g_on_local", "g_node_idx"] + ) - # share new sample position with all hosts - # ctx.hosts.put('new_sample_pos', (new_sample_pos.as_tensor(), new_sample_pos.get_indexer(target='sample_id'))) + new_sample_pos = updated_sample_pos.create_frame(columns=["g_node_idx"]) + new_sample_pos.rename(columns={"g_node_idx": "node_idx"}) ctx.hosts.put('new_sample_pos', new_sample_pos) self.sample_pos = new_sample_pos diff --git a/python/fate/ml/ensemble/learner/decision_tree/hetero/host.py b/python/fate/ml/ensemble/learner/decision_tree/hetero/host.py index 1d5cb3e10f..a4fcdfc932 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/hetero/host.py +++ b/python/fate/ml/ensemble/learner/decision_tree/hetero/host.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _get_sample_on_local_nodes, _update_sample_pos, FeatureImportance +from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _update_sample_pos_on_local_nodes, FeatureImportance from fate.ml.ensemble.learner.decision_tree.tree_core.hist import SBTHistogramBuilder, DistributedHistogram from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import FedSBTSplitter from fate.arch.histogram.histogram import ShuffledHistogram @@ -81,34 +81,10 @@ def _update_sample_pos(self, ctx, cur_layer_nodes: List[Node], sample_pos: DataF sitename = ctx.local.party[0] + '_' + ctx.local.party[1] data_with_pos = DataFrame.hstack([data, sample_pos]) - map_func = functools.partial(_get_sample_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename) - # local_sample_idx = data_with_pos.apply_row(map_func).as_tensor() - # local_samples = data_with_pos[local_sample_idx] - local_sample_idx = data_with_pos.apply_row(map_func) - local_samples = data_with_pos.iloc(local_sample_idx) - logger.info('{} samples on local nodes'.format(len(local_samples))) - - if len(local_samples) == 0: - updated_sample_pos = None - else: - update_func = functools.partial(_update_sample_pos, cur_layer_node=cur_layer_nodes, node_map=node_map) - updated_sample_pos = local_samples.create_frame() - updated_sample_pos["node_idx"] = local_samples.apply_row(update_func) + map_func = functools.partial(_update_sample_pos_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename) + update_sample_pos = data_with_pos.apply_row(map_func, columns=["h_on_local", "h_node_idx"]) - # synchronize sample pos - if updated_sample_pos is None: - update_data = (False, None) - else: - pos_data = updated_sample_pos.as_tensor() - pos_index = updated_sample_pos.get_indexer(target='sample_id') - update_data = (True, (pos_data, pos_index)) - ctx.guest.put('updated_data', update_data) - """ - new_pos_data, new_pos_indexer = ctx.guest.get('new_sample_pos') - new_sample_pos = sample_pos.create_frame() - new_sample_pos = new_sample_pos.loc(new_pos_indexer, preserve_order=True) - new_sample_pos['node_idx'] = new_pos_data - """ + ctx.guest.put('updated_data', update_sample_pos) new_sample_pos = ctx.guest.get('new_sample_pos') return new_sample_pos diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/decision_tree.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/decision_tree.py index 0b7ead89e9..a5bc6044a9 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/tree_core/decision_tree.py +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/decision_tree.py @@ -182,6 +182,21 @@ def _get_sample_on_local_nodes(s: pd.Series, cur_layer_node: List[Node], node_ma return on_local_node +def _update_sample_pos_on_local_nodes(s: pd.Series, cur_layer_node: List[Node], node_map: dict, sitename): + on_local_node = _get_sample_on_local_nodes(s, cur_layer_node, node_map, sitename) + if not on_local_node: + return False, -1 + else: + return True, _update_sample_pos(s, cur_layer_node, node_map, sitename) + + +def _merge_sample_pos(s: pd.Series): + if s['g_on_local']: + return s['g_on_local'], s['g_node_idx'] + else: + return s['h_on_local'], s['h_node_idx'] + + def _convert_sample_pos_to_score(s: pd.Series, tree_nodes: List[Node]): node_idx = s[0]