diff --git a/python/fate/arch/dataframe/conf/default_config.py b/python/fate/arch/dataframe/conf/default_config.py
index 41b439b1f0..5b7aa0404b 100644
--- a/python/fate/arch/dataframe/conf/default_config.py
+++ b/python/fate/arch/dataframe/conf/default_config.py
@@ -13,5 +13,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-DATAFRAME_BLOCK_ROW_SIZE = 2**6
-
+DATAFRAME_BLOCK_ROW_SIZE = 2**7
+BLOCK_COMPRESS_THRESHOLD = 5
diff --git a/python/fate/arch/dataframe/ops/_compress_block.py b/python/fate/arch/dataframe/ops/_compress_block.py
index 3c71c5c1ee..c1f955c3eb 100644
--- a/python/fate/arch/dataframe/ops/_compress_block.py
+++ b/python/fate/arch/dataframe/ops/_compress_block.py
@@ -17,34 +17,44 @@
 import torch
 
 from ..manager import BlockType
 from ..manager import DataManager
+from ..conf.default_config import BLOCK_COMPRESS_THRESHOLD
 
 
-def compress_blocks(block_table, data_manager: DataManager):
-    to_compress_block_loc, non_compress_block_changes = data_manager.compress_blocks()
-    if not to_compress_block_loc:
+def compress_blocks(block_table, data_manager: DataManager, force_compress=False):
+    compressed_data_manager = data_manager.duplicate()
+    to_compress_block_loc, non_compress_block_changes = compressed_data_manager.compress_blocks()
+
+    compress_block_size = 0
+    for _, block_loc in to_compress_block_loc:
+        compress_block_size += len(block_loc)
+
+    if not to_compress_block_loc or (not force_compress and compress_block_size <= BLOCK_COMPRESS_THRESHOLD):
         return block_table, data_manager
 
     def _compress(blocks):
-        ret_blocks = [[] for _ in range(data_manager.block_num)]
+        ret_blocks = [[] for _ in range(compressed_data_manager.block_num)]
         for src_bid, dst_bid in non_compress_block_changes.items():
             ret_blocks[dst_bid] = blocks[src_bid]
 
         lines = len(blocks[0])
         for dst_bid, block_loc in to_compress_block_loc:
-            field_len = len(data_manager.get_block(dst_bid).field_indexes)
-            block = data_manager.get_block(dst_bid)
+            block = compressed_data_manager.get_block(dst_bid)
+            field_len = len(block.field_indexes)
+            # TODO: empty block create logic should move to block_manager later,
+            # we pull it here as block_manager has more type like phe_tensor/pd.Index, which should not be considered in compressing
             if BlockType.is_tensor(block.block_type):
-                block_buf = torch.empty((lines, field_len), dtype=getattr(torch, block.block_type.value))
-            else:
                 block_buf = np.empty((lines, field_len), dtype=getattr(np, block.block_type.value))
+            else:
+                block_buf = np.empty((lines, field_len), dtype=object)
 
             for src_bid, field_indexes in block_loc:
                 block_buf[:, field_indexes] = blocks[src_bid]
 
-            ret_blocks[dst_bid] = block.convert_block(block_buf)
+            if isinstance(block_buf, np.ndarray):
+                ret_blocks[dst_bid] = torch.from_numpy(block_buf)
 
         return ret_blocks
 
     block_table = block_table.mapValues(_compress)
-    return block_table, data_manager
+    return block_table, compressed_data_manager
diff --git a/python/fate/arch/dataframe/ops/_histogram.py b/python/fate/arch/dataframe/ops/_histogram.py
index 4f1c75e0ab..a0e9a4234c 100644
--- a/python/fate/arch/dataframe/ops/_histogram.py
+++ b/python/fate/arch/dataframe/ops/_histogram.py
@@ -29,7 +29,7 @@ def hist(df: DataFrame, targets):
     data_manager = df.data_manager
     column_names = data_manager.infer_operable_field_names()
 
-    block_table, data_manager = _try_to_compress_table(df.block_table, data_manager)
+    block_table, data_manager = _try_to_compress_table(df.block_table, data_manager, force_compress=True)
     block_id = data_manager.infer_operable_blocks()[0]
 
     def _mapper(blocks, target, bid: int = None):
@@ -47,7 +47,7 @@ def _reducer(l_histogram, r_histogram):
 
 
 def distributed_hist_stat(df: DataFrame, distributed_hist, position: DataFrame, targets: Union[dict, DataFrame]):
-    block_table, data_manager = _try_to_compress_table(df.block_table, df.data_manager)
+    block_table, data_manager = _try_to_compress_table(df.block_table, df.data_manager, force_compress=True)
     data_block_id = data_manager.infer_operable_blocks()[0]
     position_block_id = position.data_manager.infer_operable_blocks()[0]
 
@@ -98,7 +98,7 @@ def _pack_with_targets(l_blocks, r_blocks):
 
     return distributed_hist.i_update(data_with_position)
 
 
-def _try_to_compress_table(block_table, data_manager: DataManager):
+def _try_to_compress_table(block_table, data_manager: DataManager, force_compress=False):
     block_indexes = data_manager.infer_operable_blocks()
     if len(block_indexes) == 1:
         return block_table, data_manager
@@ -119,8 +119,6 @@ def _try_to_compress_table(block_table, data_manager: DataManager):
             to_promote_types.append((bid, block_type))
 
     data_manager.promote_types(to_promote_types)
-    block_table, data_manager = compress_blocks(block_table, data_manager)
+    block_table, data_manager = compress_blocks(block_table, data_manager, force_compress=force_compress)
 
     return block_table, data_manager
-
-