dataframe: optimize indexer for psi
Signed-off-by: mgqa34 <mgq3374541@163.com>
mgqa34 committed Aug 14, 2023
1 parent 8eff556 commit b80b291
Showing 2 changed files with 38 additions and 14 deletions.
python/fate/arch/dataframe/_dataframe.py (14 changes: 3 additions & 11 deletions)
@@ -101,17 +101,9 @@ def weight(self):
     @property
     def shape(self) -> "tuple":
         if self._count is None:
-            if self._sample_id_indexer:
-                items = self._sample_id_indexer.count()
-            elif self._match_id_indexer:
-                items = self._match_id_indexer.count()
-            else:
-                if self._block_table.count() == 0:
-                    items = 0
-                else:
-                    items = self._block_table.mapValues(lambda block: 0 if block is None else len(block[0])).reduce(
-                        lambda size1, size2: size1 + size2
-                    )
+            items = 0
+            for _, v in self._partition_order_mappings.items():
+                items += v["end_index"] - v["start_index"] + 1
             self._count = items
 
         return self._count, len(self._data_manager.schema.columns)
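
Note on this hunk: shape previously fell back to a distributed count() or a
mapValues/reduce pass over the block table; it now sums precomputed
per-partition row ranges. A minimal local sketch of the new logic, assuming
each _partition_order_mappings entry carries inclusive "start_index"/"end_index"
bounds (the keys used in the diff; the values below are toy data):

partition_order_mappings = {
    "block_0": {"start_index": 0, "end_index": 9999},
    "block_1": {"start_index": 10000, "end_index": 14999},
}

items = 0
for _, v in partition_order_mappings.items():
    items += v["end_index"] - v["start_index"] + 1  # inclusive bounds
print(items)  # 15000 rows, with no job launched on the distributed table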
python/fate/arch/dataframe/ops/_indexer.py (38 changes: 35 additions & 3 deletions)
@@ -108,6 +108,10 @@ def _get_block_summary(kvs):
 
     start_index, acc_block_num = 0, 0
     block_order_mappings = dict()
+
+    if not block_summary:
+        return block_order_mappings
+
     for blk_key, blk_size in sorted(block_summary.items()):
         block_num = (blk_size + block_row_size - 1) // block_row_size
         block_order_mappings[blk_key] = dict(
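
Note: the added guard returns an empty mapping for a zero-row table instead of
entering the loop below it. For reference, the loop's
(blk_size + block_row_size - 1) // block_row_size is integer ceiling division,
i.e. the number of fixed-size blocks a partition needs:

block_row_size = 1000
for blk_size in (0, 1, 999, 1000, 1001):
    block_num = (blk_size + block_row_size - 1) // block_row_size
    print(blk_size, "->", block_num)  # 0 -> 0, 1 -> 1, 999 -> 1, 1000 -> 1, 1001 -> 2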
@@ -241,6 +245,18 @@ def flatten_data(df: DataFrame, key_type="block_id", with_sample_id=True):
         df.data_manager.schema.sample_id_name, with_offset=False
     ) if with_sample_id else None
 
+    def _flatten_with_block_id_key(kvs):
+        for block_id, blocks in kvs:
+            for row_id in range(len(blocks[0])):
+                if with_sample_id:
+                    yield (block_id, row_id), (
+                        blocks[sample_id_index][row_id],
+                        [Block.transform_row_to_raw(block, row_id) for block in blocks]
+                    )
+                else:
+                    yield (block_id, row_id), [Block.transform_row_to_raw(block, row_id) for block in blocks]
+
+    """
     def _flatten_with_block_id_key(block_id, blocks):
         for row_id in range(len(blocks[0])):
             if with_sample_id:
@@ -250,9 +266,11 @@ def _flatten_with_block_id_key(block_id, blocks):
                 )
             else:
                 yield (block_id, row_id), [Block.transform_row_to_raw(block, row_id) for block in blocks]
+    """
 
     if key_type == "block_id":
-        return df.block_table.flatMap(_flatten_with_block_id_key)
+        return df.block_table.mapPartitions(_flatten_with_block_id_key, use_previous_behavior=False)
+        # return df.block_table.flatMap(_flatten_with_block_id_key)
     else:
         raise ValueError(f"Not Implement key_type={key_type} of flatten_data.")

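Note: flatten_data now hands the whole partition iterator to one callback
(mapPartitions) instead of invoking flatMap's callback once per key/value pair,
which is why the helper's signature changed from (block_id, blocks) to (kvs)
above. A simplified, runnable contrast of the two callback shapes, with a plain
list standing in for the distributed table:

partition = [("block_0", [["alice", "bob"]]), ("block_1", [["carol"]])]

def flat_map_style(block_id, blocks):
    # old shape: one call per (key, value) pair
    for row_id in range(len(blocks[0])):
        yield (block_id, row_id), blocks[0][row_id]

def map_partitions_style(kvs):
    # new shape: one call per partition, looping over its pairs itself
    for block_id, blocks in kvs:
        for row_id in range(len(blocks[0])):
            yield (block_id, row_id), blocks[0][row_id]

old = [kv for key, blocks in partition for kv in flat_map_style(key, blocks)]
new = list(map_partitions_style(partition))
assert old == new  # same records, fewer callback dispatches
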
@@ -302,6 +320,9 @@ def loc_with_sample_id_replacement(df: DataFrame, indexer):
         row: (key=random_key,
               value=(sample_id, (src_block_id, src_offset))
     """
+    if indexer.count() == 0:
+        return df.empty_frame()
+
     data_manager = df.data_manager
     partition_order_mappings = get_partition_order_by_raw_table(indexer,
                                                                 data_manager.block_row_size,
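
Note: the new early return short-circuits an empty intersection (the PSI case
this commit targets): a zero-row indexer yields df.empty_frame() immediately
instead of building partition mappings and shuffling. Per the docstring, each
indexer row looks like the toy record below (values are illustrative):

record = ("rk_0001", ("alice", (0, 42)))  # (random_key, (sample_id, (src_block_id, src_offset)))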
@@ -331,7 +352,7 @@ def _aggregate(kvs):
     sample_id_index = data_manager.loc_block(data_manager.schema.sample_id_name, with_offset=False)
     block_num = data_manager.block_num
 
-    def _convert_to_block(kvs):
+    def _convert_to_row(kvs):
         ret_dict = {}
         for block_id, (blocks, block_indexer) in kvs:
             for src_row_id, sample_id, dst_block_id, dst_row_id in block_indexer:
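
Note: the rename to _convert_to_row matches what the function emits: per-row
tuples grouped under their destination block id, with column-wise conversion
deferred to the new _convert_to_frame_block below. A simplified, self-contained
sketch of that grouping (toy data; the real code extracts rows via
Block.transform_row_to_raw and carries the sample_id along):

blocks = [["alice", "bob"], [[1.0], [2.0]]]             # two schema blocks, two source rows
block_indexer = [(0, "alice", 7, 0), (1, "bob", 7, 1)]  # (src_row_id, sample_id, dst_block_id, dst_row_id)

ret_dict = {}
for src_row_id, sample_id, dst_block_id, dst_row_id in block_indexer:
    row_values = [block[src_row_id] for block in blocks]  # pull the raw row across all blocks
    ret_dict.setdefault(dst_block_id, []).append((dst_row_id, row_values))
print(ret_dict)  # {7: [(0, ['alice', [1.0]]), (1, ['bob', [2.0]])]}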
@@ -353,13 +374,24 @@ def _convert_to_block(kvs):
 
         return ret_dict.items()
 
+    def _convert_to_frame_block(blocks):
+        convert_blocks = []
+        for idx, block_schema in enumerate(data_manager.blocks):
+            block_content = [row_data[1][idx] for row_data in blocks]
+            convert_blocks.append(block_schema.convert_block(block_content))
+
+        return convert_blocks
+
     agg_indexer = indexer.mapReducePartitions(_aggregate, lambda l1, l2: l1 + l2)
     block_table = df.block_table.join(agg_indexer, lambda v1, v2: (v1, v2))
-    block_table = block_table.mapReducePartitions(_convert_to_block, _merge_list)
+    block_table = block_table.mapReducePartitions(_convert_to_row, _merge_list)
+    block_table = block_table.mapValues(_convert_to_frame_block)
+    """
     block_table = block_table.mapValues(lambda values: [v[1] for v in values])
     from ._transformer import transform_list_block_to_frame_block
     block_table = transform_list_block_to_frame_block(block_table, df.data_manager)
+    """
 
     return DataFrame(
         ctx=df._ctx,
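
Note: the new _convert_to_frame_block replaces the commented-out detour through
transform_list_block_to_frame_block: it transposes the merged per-row records
back into per-block columns in place. A runnable toy version, with
block_schema.convert_block stood in by the identity:

rows = [
    (0, [["alice"], [1.0, 2.0]]),  # (dst_row_id, one value-list per schema block)
    (1, [["bob"],   [3.0, 4.0]]),
]

num_blocks = 2
convert_blocks = []
for idx in range(num_blocks):
    block_content = [row_data[1][idx] for row_data in rows]  # gather column idx across rows
    convert_blocks.append(block_content)  # real code: block_schema.convert_block(block_content)
print(convert_blocks)  # [[['alice'], ['bob']], [[1.0, 2.0], [3.0, 4.0]]]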
