diff --git a/python/fate/arch/dataframe/_dataframe.py b/python/fate/arch/dataframe/_dataframe.py index e56fd8a8ff..f1f96a4e05 100644 --- a/python/fate/arch/dataframe/_dataframe.py +++ b/python/fate/arch/dataframe/_dataframe.py @@ -101,17 +101,9 @@ def weight(self): @property def shape(self) -> "tuple": if self._count is None: - if self._sample_id_indexer: - items = self._sample_id_indexer.count() - elif self._match_id_indexer: - items = self._match_id_indexer.count() - else: - if self._block_table.count() == 0: - items = 0 - else: - items = self._block_table.mapValues(lambda block: 0 if block is None else len(block[0])).reduce( - lambda size1, size2: size1 + size2 - ) + items = 0 + for _, v in self._partition_order_mappings.items(): + items += v["end_index"] - v["start_index"] + 1 self._count = items return self._count, len(self._data_manager.schema.columns) diff --git a/python/fate/arch/dataframe/ops/_indexer.py b/python/fate/arch/dataframe/ops/_indexer.py index 0e8b3ea750..deb7f2b585 100644 --- a/python/fate/arch/dataframe/ops/_indexer.py +++ b/python/fate/arch/dataframe/ops/_indexer.py @@ -108,6 +108,10 @@ def _get_block_summary(kvs): start_index, acc_block_num = 0, 0 block_order_mappings = dict() + + if not block_summary: + return block_order_mappings + for blk_key, blk_size in sorted(block_summary.items()): block_num = (blk_size + block_row_size - 1) // block_row_size block_order_mappings[blk_key] = dict( @@ -241,6 +245,18 @@ def flatten_data(df: DataFrame, key_type="block_id", with_sample_id=True): df.data_manager.schema.sample_id_name, with_offset=False ) if with_sample_id else None + def _flatten_with_block_id_key(kvs): + for block_id, blocks in kvs: + for row_id in range(len(blocks[0])): + if with_sample_id: + yield (block_id, row_id), ( + blocks[sample_id_index][row_id], + [Block.transform_row_to_raw(block, row_id) for block in blocks] + ) + else: + yield (block_id, row_id), [Block.transform_row_to_raw(block, row_id) for block in blocks] + + """ def 
_flatten_with_block_id_key(block_id, blocks): for row_id in range(len(blocks[0])): if with_sample_id: @@ -250,9 +266,11 @@ def _flatten_with_block_id_key(block_id, blocks): ) else: yield (block_id, row_id), [Block.transform_row_to_raw(block, row_id) for block in blocks] + """ if key_type == "block_id": - return df.block_table.flatMap(_flatten_with_block_id_key) + return df.block_table.mapPartitions(_flatten_with_block_id_key, use_previous_behavior=False) + # NOTE: mapPartitions(use_previous_behavior=False) replaces flatMap to cut per-record dispatch overhead else: raise ValueError(f"Not Implement key_type={key_type} of flatten_data.") @@ -302,6 +320,9 @@ def loc_with_sample_id_replacement(df: DataFrame, indexer): row: (key=random_key, value=(sample_id, (src_block_id, src_offset)) """ + if indexer.count() == 0: + return df.empty_frame() + data_manager = df.data_manager partition_order_mappings = get_partition_order_by_raw_table(indexer, data_manager.block_row_size, @@ -331,7 +352,7 @@ def _aggregate(kvs): sample_id_index = data_manager.loc_block(data_manager.schema.sample_id_name, with_offset=False) block_num = data_manager.block_num - def _convert_to_block(kvs): + def _convert_to_row(kvs): ret_dict = {} for block_id, (blocks, block_indexer) in kvs: for src_row_id, sample_id, dst_block_id, dst_row_id in block_indexer: @@ -353,13 +374,24 @@ def _convert_to_block(kvs): return ret_dict.items() + def _convert_to_frame_block(blocks): + convert_blocks = [] + for idx, block_schema in enumerate(data_manager.blocks): + block_content = [row_data[1][idx] for row_data in blocks] + convert_blocks.append(block_schema.convert_block(block_content)) + + return convert_blocks + agg_indexer = indexer.mapReducePartitions(_aggregate, lambda l1, l2: l1 + l2) block_table = df.block_table.join(agg_indexer, lambda v1, v2: (v1, v2)) - block_table = block_table.mapReducePartitions(_convert_to_block, _merge_list) + block_table = block_table.mapReducePartitions(_convert_to_row, _merge_list) + block_table =
block_table.mapValues(_convert_to_frame_block) + """ block_table = block_table.mapValues(lambda values: [v[1] for v in values]) from ._transformer import transform_list_block_to_frame_block block_table = transform_list_block_to_frame_block(block_table, df.data_manager) + """ return DataFrame( ctx=df._ctx,