Skip to content

Commit

Permalink
dataframe: update to support apply_row ret is phe_tensor
Browse files Browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
  • Loading branch information
mgqa34 committed Aug 23, 2023
1 parent 025b0e1 commit 9b6a58c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
12 changes: 8 additions & 4 deletions python/fate/arch/dataframe/manager/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,12 @@ def __init__(self, *args, **kwargs):

@staticmethod
def convert_block(block):
try:
return torch.tensor(block, dtype=torch.int32)
except ValueError:
return torch.tensor(np.array(block, dtype="int32"), dtype=torch.int32)
if isinstance(block, torch.Tensor):
if block.dtype == torch.int32:
return block.clone().detach()
else:
return block.to(torch.int32)
return torch.tensor(np.array(block, dtype="int32"), dtype=torch.int32)


class Int64Block(Block):
Expand Down Expand Up @@ -365,6 +367,8 @@ def set_extra_kwargs(self, pk, evaluator, coder, dtype, device):

@staticmethod
def convert_block(block):
if isinstance(block, list):
block = block[0].cat(block[1:])
return block

def convert_to_phe_tensor(self, block, shape):
Expand Down
13 changes: 12 additions & 1 deletion python/fate/arch/dataframe/ops/_apply_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ def _apply(blocks, func=None, src_field_names=None,
ret_blocks[idx] = blocks[bid]

for idx, bid in enumerate(block_indexes):
ret_blocks[bid] = dm.blocks[bid].convert_block(apply_blocks[idx])
if dm.blocks[bid].is_phe_tensor():
single_value = apply_blocks[idx][0][0]
dm.blocks[bid].set_extra_kwargs(pk=single_value.pk,
evaluator=single_value.evaluator,
coder=single_value.coder,
dtype=single_value.dtype,
device=single_value.device)
ret = [v[0]._data for v in apply_blocks[idx]]
ret_blocks[bid] = dm.blocks[bid].convert_block(ret)
# ret_blocks[bid] = dm.blocks[bid].convert_to_phe_tensor(ret, shape=(len(ret), 1))
else:
ret_blocks[bid] = dm.blocks[bid].convert_block(apply_blocks[idx])

return ret_blocks, dm
11 changes: 8 additions & 3 deletions python/fate/arch/dataframe/ops/_set_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,15 @@ def _append_multi(blocks, item_list, bid_list=None, dm: DataManager=None):

return ret_blocks

def _append_df(l_blocks, r_blocks, r_blocks_loc=None):
def _append_df(l_blocks, r_blocks, r_blocks_loc=None, dm=None):
ret_blocks = [block for block in l_blocks]
l_bid = len(ret_blocks)
for bid, offset in r_blocks_loc:
ret_blocks.append(r_blocks[bid][:, [offset]])
if dm.blocks[bid].is_phe_tensor():
ret_blocks.append(r_blocks[bid])
else:
ret_blocks.append(r_blocks[bid][:, [offset]])
l_bid += 1

return ret_blocks

Expand Down Expand Up @@ -128,7 +133,7 @@ def _append_phe_tensor(l_blocks, r_tensor):
raise ValueError("Setitem with rhs=DataFrame must have equal len keys")
data_manager.append_columns(keys, block_types)

_append_func = functools.partial(_append_df, r_blocks_loc=operable_blocks_loc)
_append_func = functools.partial(_append_df, r_blocks_loc=operable_blocks_loc, dm=data_manager)
block_table = df.block_table.join(items.block_table, _append_func)
elif isinstance(items, DTensor):
meta_data = items.shardings._data.mapValues(
Expand Down

0 comments on commit 9b6a58c

Please sign in to comment.