From 7c054aff3d12fdc0ccfbb938d741a38814b1c9d8 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 4 Jul 2023 19:25:31 +0800 Subject: [PATCH] dataframe fix: fix cmp between tensor and ndarray/pd.Series Signed-off-by: mgqa34 --- python/fate/arch/dataframe/ops/_cmp.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/fate/arch/dataframe/ops/_cmp.py b/python/fate/arch/dataframe/ops/_cmp.py index d1187b8f5d..4097136c74 100644 --- a/python/fate/arch/dataframe/ops/_cmp.py +++ b/python/fate/arch/dataframe/ops/_cmp.py @@ -17,6 +17,7 @@ import torch from .._dataframe import DataFrame from .._dataframe import DataManager +from ..manager import BlockType from .utils.operators import binary_operate from .utils.series_align import series_to_ndarray from ._compress_block import compress_blocks @@ -42,10 +43,13 @@ def cmp_operate(lhs: DataFrame, rhs, op) -> "DataFrame": rhs = rhs.reshape(-1) field_indexes = [data_manager.get_field_offset(name) for name in column_names] field_indexes_mappings = dict(zip(field_indexes, range(len(field_indexes)))) - rhs_blocks = [np.array([]) for i in range(data_manager.block_num)] + rhs_blocks = [np.array([]) for _ in range(data_manager.block_num)] for bid in block_indexes: indexer = [field_indexes_mappings[field] for field in data_manager.get_block(bid).field_indexes] - rhs_blocks[bid] = rhs[indexer] + if BlockType.is_tensor(data_manager.get_block(bid).block_type): + rhs_blocks[bid] = torch.Tensor(rhs[indexer]) + else: + rhs_blocks[bid] = rhs[indexer] block_table = binary_operate(lhs.block_table, rhs_blocks, op, block_indexes) @@ -64,7 +68,7 @@ def cmp_operate(lhs: DataFrame, rhs, op) -> "DataFrame": ] block_table = _cmp_dfs(lhs.block_table, rhs.block_table, op, lhs_block_loc, rhs_block_loc, - block_indexes, indexers) + block_indexes, indexers) else: raise ValueError(f"Not implement comparison of rhs type={type(rhs)}") @@ -76,6 +80,7 @@ def cmp_operate(lhs: DataFrame, rhs, op) -> "DataFrame": data_manager ) + def _merge_bool_blocks(block_table, data_manager: DataManager, block_indexes): """ all blocks are bool type, they should be merge into one blocks @@ -83,7 +88,7 @@ def _merge_bool_blocks(block_table, data_manager: DataManager, block_indexes): dst_data_manager = data_manager.duplicate() to_promote_types = [] for bid in block_indexes: - to_promote_types.append((bid, torch.bool)) + to_promote_types.append((bid, BlockType.bool)) dst_data_manager.promote_types(to_promote_types) dst_block_table, dst_data_manager = compress_blocks(block_table, dst_data_manager)