dataframe: speed up loc op #5092

Merged · 1 commit · Aug 29, 2023
8 changes: 4 additions & 4 deletions python/fate/arch/dataframe/_dataframe.py

@@ -300,7 +300,7 @@ def hist(self, targets):

         return hist(self, targets)

-    def distributed_hist_stat(self, distributed_hist, position: "DataFrame", targets: dict):
+    def distributed_hist_stat(self, distributed_hist, position: "DataFrame", targets: Union[dict, "DataFrame"]):
         from .ops._histogram import distributed_hist_stat

         return distributed_hist_stat(self, distributed_hist, position, targets)
@@ -529,12 +529,12 @@ def copy(self) -> "DataFrame":
         )

     @classmethod
-    def from_flatten_data(cls, ctx, flatten_table, data_manager) -> "DataFrame":
+    def from_flatten_data(cls, ctx, flatten_table, data_manager, key_type) -> "DataFrame":
         """
         key=random_key, value=(sample_id, data)
         """
         from .ops._indexer import transform_flatten_data_to_df
-        return transform_flatten_data_to_df(ctx, flatten_table, data_manager)
+        return transform_flatten_data_to_df(ctx, flatten_table, data_manager, key_type)

     @classmethod
     def hstack(cls, stacks: List["DataFrame"]) -> "DataFrame":
@@ -583,4 +583,4 @@ def __convert_to_table(self, target_name):
     def data_overview(self, num=100):
         from .ops._data_overview import collect_data

-        return collect_data(self, num=100)
+        return collect_data(self, num=num)
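
The last hunk is also a straight bug fix: data_overview accepted a num argument but forwarded the literal 100 to collect_data, so the caller's value was silently ignored. A hypothetical usage sketch of the corrected behavior (df here is any constructed DataFrame, not part of this diff):

    # Hypothetical usage; df is any fate.arch.dataframe.DataFrame instance.
    preview = df.data_overview(num=10)  # honors num now; the old code always fetched 100 rows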
192 changes: 130 additions & 62 deletions python/fate/arch/dataframe/ops/_indexer.py

@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 import functools
-import numpy as np

 from ..manager import Block, DataManager
 from .._dataframe import DataFrame
@@ -197,90 +196,96 @@ def _merge_list(lhs, rhs):
     return ret


-def loc(df: DataFrame, indexer, target, preserve_order=False):
-    self_indexer = df.get_indexer(target)
-    if preserve_order:
-        indexer = self_indexer.join(indexer, lambda lhs, rhs: (lhs, rhs))
-    else:
-        indexer = self_indexer.join(indexer, lambda lhs, rhs: (lhs, lhs))
+# def loc(df: DataFrame, indexer, target, preserve_order=False):
+#     self_indexer = df.get_indexer(target)
+#     if preserve_order:
+#         indexer = self_indexer.join(indexer, lambda lhs, rhs: (lhs, rhs))
+#     else:
+#         indexer = self_indexer.join(indexer, lambda lhs, rhs: (lhs, lhs))

-    if indexer.count() == 0:
-        return df.empty_frame()
+#     if indexer.count() == 0:
+#         return df.empty_frame()

-    agg_indexer = aggregate_indexer(indexer)
+#     agg_indexer = aggregate_indexer(indexer)

-    if not preserve_order:
-        def _convert_block(blocks, retrieval_indexes):
-            row_indexes = [retrieval_index[0] for retrieval_index in retrieval_indexes]
-            return [Block.retrieval_row(block, row_indexes) for block in blocks]
+#     if not preserve_order:
+#         def _convert_block(blocks, retrieval_indexes):
+#             row_indexes = [retrieval_index[0] for retrieval_index in retrieval_indexes]
+#             return [Block.retrieval_row(block, row_indexes) for block in blocks]

-        block_table = df.block_table.join(agg_indexer, _convert_block)
-    else:
-        def _convert_to_row(kvs):
-            ret_dict = {}
-            for block_id, (blocks, block_indexer) in kvs:
-                """
-                block_indexer: row_id, (new_block_id, new_row_id)
-                """
-                flat_blocks = [Block.transform_block_to_list(block) for block in blocks]
-                block_num = len(flat_blocks)
-                for src_row_id, (dst_block_id, dst_row_id) in block_indexer:
-                    if dst_block_id not in ret_dict:
-                        ret_dict[dst_block_id] = []
-
-                    ret_dict[dst_block_id].append(
-                        (dst_row_id, [flat_blocks[i][src_row_id] for i in range(block_num)])
-                    )
-
-            for dst_block_id, value_list in ret_dict.items():
-                yield dst_block_id, sorted(value_list)
-
-        block_table = df.block_table.join(agg_indexer, lambda lhs, rhs: (lhs, rhs))
-        block_table = block_table.mapReducePartitions(_convert_to_row, _merge_list)
-
-    _convert_to_frame_block_func = functools.partial(_convert_to_frame_block, data_manager=df.data_manager)
-    block_table = block_table.mapValues(_convert_to_frame_block_func)
-
-    partition_order_mappings = get_partition_order_mappings_by_block_table(block_table, df.data_manager.block_row_size)
-    return DataFrame(
-        df._ctx,
-        block_table,
-        partition_order_mappings,
-        df.data_manager.duplicate())
+#         block_table = df.block_table.join(agg_indexer, _convert_block)
+#     else:
+#         def _convert_to_row(kvs):
+#             ret_dict = {}
+#             for block_id, (blocks, block_indexer) in kvs:
+#                 """
+#                 block_indexer: row_id, (new_block_id, new_row_id)
+#                 """
+#                 flat_blocks = [Block.transform_block_to_list(block) for block in blocks]
+#                 block_num = len(flat_blocks)
+#                 for src_row_id, (dst_block_id, dst_row_id) in block_indexer:
+#                     if dst_block_id not in ret_dict:
+#                         ret_dict[dst_block_id] = []

+#                     ret_dict[dst_block_id].append(
+#                         (dst_row_id, [flat_blocks[i][src_row_id] for i in range(block_num)])
+#                     )

+#             for dst_block_id, value_list in ret_dict.items():
+#                 yield dst_block_id, sorted(value_list)

+#         block_table = df.block_table.join(agg_indexer, lambda lhs, rhs: (lhs, rhs))
+#         block_table = block_table.mapReducePartitions(_convert_to_row, _merge_list)

+#     _convert_to_frame_block_func = functools.partial(_convert_to_frame_block, data_manager=df.data_manager)
+#     block_table = block_table.mapValues(_convert_to_frame_block_func)

+#     partition_order_mappings = get_partition_order_mappings_by_block_table(block_table, df.data_manager.block_row_size)
+#     return DataFrame(
+#         df._ctx,
+#         block_table,
+#         partition_order_mappings,
+#         df.data_manager.duplicate())


 def flatten_data(df: DataFrame, key_type="block_id", with_sample_id=True):
     """
     key_type="block_id":
         key=(block_id, block_offset), value=data_row
     key_type="sample_id":
         key=sample_id, value=data_row
     """
     sample_id_index = df.data_manager.loc_block(
         df.data_manager.schema.sample_id_name, with_offset=False
-    ) if with_sample_id else None
+    ) if (with_sample_id or key_type == "sample_id") else None

-    def _flatten_with_block_id_key(kvs):
+    def _flatten(kvs):
         for block_id, blocks in kvs:
             flat_blocks = [Block.transform_block_to_list(block) for block in blocks]
             block_num = len(flat_blocks)
-            for row_id in range(len(blocks[0])):
-                if with_sample_id:
-                    yield (block_id, row_id), (
-                        flat_blocks[sample_id_index][row_id],
-                        [flat_blocks[i][row_id] for i in range(block_num)]
-                    )
-                else:
-                    yield (block_id, row_id), [flat_blocks[i][row_id] for i in range(block_num)]
-
-    if key_type == "block_id":
-        return df.block_table.mapPartitions(_flatten_with_block_id_key, use_previous_behavior=False)
+            if key_type == "block_id":
+                for row_id in range(len(blocks[0])):
+                    if with_sample_id:
+                        yield (block_id, row_id), (
+                            flat_blocks[sample_id_index][row_id],
+                            [flat_blocks[i][row_id] for i in range(block_num)]
+                        )
+                    else:
+                        yield (block_id, row_id), [flat_blocks[i][row_id] for i in range(block_num)]
+            else:
+                for row_id in range(len(blocks[0])):
+                    yield flat_blocks[sample_id_index][row_id], [flat_blocks[i][row_id] for i in range(block_num)]
+
+    if key_type in ["block_id", "sample_id"]:
+        return df.block_table.mapPartitions(_flatten, use_previous_behavior=False)
     else:
         raise ValueError(f"Not Implement key_type={key_type} of flatten_data.")


-def transform_flatten_data_to_df(ctx, flatten_table, data_manager: DataManager):
+def transform_flatten_data_to_df(ctx, flatten_table, data_manager: DataManager, key_type):
     partition_order_mappings = get_partition_order_by_raw_table(flatten_table,
                                                                 data_manager.block_row_size,
-                                                                key_type="block_id")
+                                                                key_type=key_type)
     block_num = data_manager.block_num

     def _convert_to_blocks(kvs):
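
For reference, the two key_type layouts that flatten_data can now emit differ only in the key, which is what lets the rewritten loc (next hunk) join a sample_id-keyed indexer directly instead of resolving positions first. A shape-only sketch of the two forms from the docstring above, with made-up ids and rows:

    # Shape-only illustration; ids and row values are made up.
    # key_type="block_id", with_sample_id=True: key=(block_id, block_offset), value=(sample_id, row)
    block_keyed = ((0, 7), ("sid_42", [1.0, 0.0, 3.5]))
    # key_type="sample_id": key=sample_id, value=row
    sample_keyed = ("sid_42", [1.0, 0.0, 3.5])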
@@ -316,6 +321,69 @@ def _convert_to_blocks(kvs):
     )


+def loc(df: DataFrame, indexer, target="sample_id", preserve_order=False):
+    """
+    indexer: table, key=sample_id, value=(block_id, block_offset)
+    """
+    if target != "sample_id":
+        raise ValueError(f"Only target=sample_id is supported, but target={target} is found")
+    flatten_table = flatten_data(df, key_type="sample_id")
+    if not preserve_order:
+        flatten_table = flatten_table.join(indexer, lambda v1, v2: v1)
+        if not flatten_table.count():
+            return df.empty_frame()
+        return transform_flatten_data_to_df(df._ctx, flatten_table, df.data_manager, key_type="sample_id")
+    else:
+        flatten_table_with_dst_indexer = flatten_table.join(indexer, lambda v1, v2: (v2[0], (v2[1], v1)))
+        if not flatten_table_with_dst_indexer.count():
+            return df.empty_frame()
+
+        def _aggregate(kvs):
+            values = [value for key, value in kvs]
+            values.sort()
+            i = 0
+            l = len(values)
+            while i < l:
+                j = i + 1
+                while j < l and values[j][0] == values[i][0]:
+                    j += 1
+
+                yield values[i][0], [values[k][1] for k in range(i, j)]
+
+                i = j
+
+        data_manager = df.data_manager.duplicate()
+        block_num = data_manager.block_num
+
+        def _to_blocks(values):
+            block_size = len(values)
+            ret_blocks = [[None] * block_size for _ in range(block_num)]
+
+            for row_id, row_data in values:
+                for j in range(block_num):
+                    ret_blocks[j][row_id] = row_data[j]
+
+            for idx, block_schema in enumerate(data_manager.blocks):
+                ret_blocks[idx] = block_schema.convert_block(ret_blocks[idx])
+
+            return ret_blocks
+
+        agg_data = flatten_table_with_dst_indexer.mapReducePartitions(_aggregate, lambda v1, v2: v1 + v2)
+        block_table = agg_data.mapValues(_to_blocks)
+
+        partition_order_mappings = get_partition_order_mappings_by_block_table(
+            block_table,
+            block_row_size=data_manager.block_row_size
+        )
+
+        return DataFrame(
+            df._ctx,
+            block_table=block_table,
+            partition_order_mappings=partition_order_mappings,
+            data_manager=data_manager
+        )
+
+
 def loc_with_sample_id_replacement(df: DataFrame, indexer):
     """
     indexer: table,
2 changes: 1 addition & 1 deletion python/fate/arch/protocol/psi/ecdh/_run.py

@@ -139,7 +139,7 @@ def guest_run(ctx, df: DataFrame, curve_type="curve25519", **kwargs):

     intersect_guest_data = intersect_with_offset_ids.mapValues(lambda v: v[0])

-    guest_df = DataFrame.from_flatten_data(ctx, intersect_guest_data, df.data_manager)
+    guest_df = DataFrame.from_flatten_data(ctx, intersect_guest_data, df.data_manager, key_type="block_id")
     ctx.metrics.log_metrics({"intersect_count": guest_df.shape[0]}, name="intersect_id_count", type="custom")

     """
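
Because key_type is now a required parameter of from_flatten_data, each call site has to name the key layout of its flatten table; the ECDH PSI guest path above keeps its previous behavior by spelling out "block_id". A hedged sketch of the two spellings this diff threads through the indexer helpers (ctx, dm, and the table names here are hypothetical):

    # Hypothetical call sites; "block_id" and "sample_id" are the two layouts
    # used by flatten_data and transform_flatten_data_to_df in this diff.
    df_by_position = DataFrame.from_flatten_data(ctx, table_keyed_by_offset, dm, key_type="block_id")
    df_by_sample_id = DataFrame.from_flatten_data(ctx, table_keyed_by_sample_id, dm, key_type="sample_id")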