Skip to content

Commit

Permalink
Merge pull request #5106 from FederatedAI/feature-2.0.0-beta-dataframe_refact
Browse files Browse the repository at this point in the history

Feature 2.0.0 beta dataframe refact
  • Loading branch information
mgqa34 authored Aug 31, 2023
2 parents a798bc6 + b5b6807 commit 5c5bccf
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 65 deletions.
2 changes: 1 addition & 1 deletion python/fate/arch/histogram/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def i_update(self, data, k=None) -> "ShuffledHistogram":
ShuffledHistogram, the shuffled(if seed is not None) histogram
"""
if k is None:
k = data.partitions
k = data.partitions ** 2
mapper = get_partition_hist_build_mapper(
self._node_size, self._feature_bin_sizes, self._value_schemas, self._seed, k
)
Expand Down
46 changes: 10 additions & 36 deletions python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _get_sample_on_local_nodes, _update_sample_pos
from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _update_sample_pos_on_local_nodes, _merge_sample_pos
from fate.ml.ensemble.learner.decision_tree.tree_core.hist import SBTHistogramBuilder
from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import FedSBTSplitter
from fate.ml.ensemble.learner.decision_tree.tree_core.loss import get_task_info
Expand Down Expand Up @@ -108,47 +108,21 @@ def _update_sample_pos(self, ctx, cur_layer_nodes: List[Node], sample_pos: DataF

sitename = ctx.local.party[0] + '_' + ctx.local.party[1]
data_with_pos = DataFrame.hstack([data, sample_pos])
map_func = functools.partial(_get_sample_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename)
local_sample_idx = data_with_pos.apply_row(map_func)
# local_samples = data_with_pos[local_sample_idx.as_tensor()]
local_samples = data_with_pos.iloc(local_sample_idx)
logger.info('{}/{} samples on local nodes'.format(len(local_samples), len(data)))
if len(local_samples) == 0:
updated_sample_pos = None
else:
"""
updated_sample_pos = sample_pos.loc(local_samples.get_indexer(target="sample_id"), preserve_order=True).create_frame()
update_func = functools.partial(_update_sample_pos, cur_layer_node=cur_layer_nodes, node_map=node_map)
map_rs = local_samples.apply_row(update_func)
updated_sample_pos["node_idx"] = map_rs # local_samples.apply_row(update_func)
"""
update_func = functools.partial(_update_sample_pos, cur_layer_node=cur_layer_nodes, node_map=node_map)
updated_sample_pos = local_samples.create_frame()
updated_sample_pos["node_idx"] = local_samples.apply_row(update_func)
map_func = functools.partial(_update_sample_pos_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename)
updated_sample_pos = data_with_pos.apply_row(map_func, columns=["g_on_local", "g_node_idx"])

# synchronize sample pos
host_update_sample_pos = ctx.hosts.get('updated_data')
new_sample_pos = sample_pos.empty_frame()

merge_func = functools.partial(_merge_sample_pos)
for host_data in host_update_sample_pos:
if host_data[0]: # True
pos_data, pos_index = host_data[1]
tmp_frame = sample_pos.create_frame()
tmp_frame = tmp_frame.loc(pos_index, preserve_order=True)
tmp_frame['node_idx'] = pos_data
new_sample_pos = DataFrame.vstack([new_sample_pos, tmp_frame])

if updated_sample_pos is not None:
if len(updated_sample_pos) == len(data): # all samples are on local
new_sample_pos = updated_sample_pos
else:
logger.info('stack new sample pos, guest len {}, host len {}'.format(len(updated_sample_pos), len(new_sample_pos)))
new_sample_pos = DataFrame.vstack([updated_sample_pos, new_sample_pos])
else:
new_sample_pos = new_sample_pos # all samples are on host
updated_sample_pos = DataFrame.hstack([updated_sample_pos, host_data]).apply_row(
merge_func,
columns=["g_on_local", "g_node_idx"]
)

# share new sample position with all hosts
# ctx.hosts.put('new_sample_pos', (new_sample_pos.as_tensor(), new_sample_pos.get_indexer(target='sample_id')))
new_sample_pos = updated_sample_pos.create_frame(columns=["g_node_idx"])
new_sample_pos.rename(columns={"g_node_idx": "node_idx"})
ctx.hosts.put('new_sample_pos', new_sample_pos)
self.sample_pos = new_sample_pos

Expand Down
32 changes: 4 additions & 28 deletions python/fate/ml/ensemble/learner/decision_tree/hetero/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _get_sample_on_local_nodes, _update_sample_pos, FeatureImportance
from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, _update_sample_pos_on_local_nodes, FeatureImportance
from fate.ml.ensemble.learner.decision_tree.tree_core.hist import SBTHistogramBuilder, DistributedHistogram
from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import FedSBTSplitter
from fate.arch.histogram.histogram import ShuffledHistogram
Expand Down Expand Up @@ -81,34 +81,10 @@ def _update_sample_pos(self, ctx, cur_layer_nodes: List[Node], sample_pos: DataF

sitename = ctx.local.party[0] + '_' + ctx.local.party[1]
data_with_pos = DataFrame.hstack([data, sample_pos])
map_func = functools.partial(_get_sample_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename)
# local_sample_idx = data_with_pos.apply_row(map_func).as_tensor()
# local_samples = data_with_pos[local_sample_idx]
local_sample_idx = data_with_pos.apply_row(map_func)
local_samples = data_with_pos.iloc(local_sample_idx)
logger.info('{} samples on local nodes'.format(len(local_samples)))

if len(local_samples) == 0:
updated_sample_pos = None
else:
update_func = functools.partial(_update_sample_pos, cur_layer_node=cur_layer_nodes, node_map=node_map)
updated_sample_pos = local_samples.create_frame()
updated_sample_pos["node_idx"] = local_samples.apply_row(update_func)
map_func = functools.partial(_update_sample_pos_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename)
update_sample_pos = data_with_pos.apply_row(map_func, columns=["h_on_local", "h_node_idx"])

# synchronize sample pos
if updated_sample_pos is None:
update_data = (False, None)
else:
pos_data = updated_sample_pos.as_tensor()
pos_index = updated_sample_pos.get_indexer(target='sample_id')
update_data = (True, (pos_data, pos_index))
ctx.guest.put('updated_data', update_data)
"""
new_pos_data, new_pos_indexer = ctx.guest.get('new_sample_pos')
new_sample_pos = sample_pos.create_frame()
new_sample_pos = new_sample_pos.loc(new_pos_indexer, preserve_order=True)
new_sample_pos['node_idx'] = new_pos_data
"""
ctx.guest.put('updated_data', update_sample_pos)
new_sample_pos = ctx.guest.get('new_sample_pos')

return new_sample_pos
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,21 @@ def _get_sample_on_local_nodes(s: pd.Series, cur_layer_node: List[Node], node_ma
return on_local_node


def _update_sample_pos_on_local_nodes(s: pd.Series, cur_layer_node: List[Node], node_map: dict, sitename):
    """Resolve one sample's new node position, but only if this party owns it.

    Returns a (on_local, node_idx) pair: (True, new_idx) when the sample sits on a
    node split by this site, otherwise (False, -1) so the other party's result is
    used during the merge step.
    """
    if _get_sample_on_local_nodes(s, cur_layer_node, node_map, sitename):
        new_idx = _update_sample_pos(s, cur_layer_node, node_map, sitename)
        return True, new_idx
    return False, -1


def _merge_sample_pos(s: pd.Series):
if s['g_on_local']:
return s['g_on_local'], s['g_node_idx']
else:
return s['h_on_local'], s['h_node_idx']


def _convert_sample_pos_to_score(s: pd.Series, tree_nodes: List[Node]):

node_idx = s[0]
Expand Down

0 comments on commit 5c5bccf

Please sign in to comment.