Skip to content

Commit

Permalink
dataframe: vstack support union with multiple column types cross diff…
Browse files Browse the repository at this point in the history
…erent dataframes

Signed-off-by: mgqa34 <mgq3374541@163.com>
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
mgqa34 authored and sagewe committed Aug 9, 2023
1 parent f45fc90 commit a7d6e50
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 43 deletions.
8 changes: 7 additions & 1 deletion python/fate/arch/dataframe/manager/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ def __lt__(self, other):

return False

def __gt__(self, other):
if self == other:
return False

return other < self

@staticmethod
def get_block_type(data_type):
if isinstance(data_type, np.dtype):
Expand Down Expand Up @@ -406,7 +412,7 @@ def split_fields(self, field_indexes, block_types):
for block_id, field_with_offset_list in block_field_maps.items():
if len(self._blocks[block_id].field_indexes) == len(field_with_offset_list):
if len(field_with_offset_list) == 1:
self._blocks[block_id] = Block.get_block_by_type(block_types)(
self._blocks[block_id] = Block.get_block_by_type(block_type)(
self._blocks[block_id].field_indexes,
should_compress=self._blocks[block_id].should_compress
)
Expand Down
113 changes: 71 additions & 42 deletions python/fate/arch/dataframe/ops/_dimension_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ..manager.data_manager import DataManager
from ._compress_block import compress_blocks
from ._indexer import get_partition_order_by_raw_table
from ._promote_types import promote_partial_block_types
from ._set_item import set_item
from fate.arch.tensor import DTensor

Expand Down Expand Up @@ -59,71 +60,99 @@ def hstack(data_frames: List["DataFrame"]) -> "DataFrame":
def vstack(data_frames: List["DataFrame"]) -> "DataFrame":
frame_0 = data_frames[0]
data_frames = list(filter(lambda df: df.shape[0], data_frames))
if not data_frames:
return frame_0
if len(data_frames) <= 1:
return frame_0 if not data_frames else data_frames[0]

if len(data_frames[0]) == 1:
return data_frames[0]

def _align_blocks(blocks, src_fields_loc=None, src_dm: DataManager = None, dst_dm: DataManager = None):
ret_blocks = []
lines = None
def _align_blocks(blocks, align_fields_loc=None, full_block_migrate_set=None, dst_dm: DataManager = None):
ret_blocks, lines = [], None
for dst_bid, block in enumerate(dst_dm.blocks):
field_indexes = block.field_indexes
src_bid = src_fields_loc[field_indexes[0]][0]
if src_dm.blocks[src_bid].field_indexes == field_indexes:
ret_blocks.append(blocks[src_bid])
_field_indexes = block.field_indexes
_src_bid = align_fields_loc[_field_indexes[0]][0]
if _src_bid in full_block_migrate_set:
ret_blocks.append(blocks[_src_bid])
else:
block_buf = []
_align_block = []
lines = len(blocks[0]) if lines is None else lines

for lid in range(lines):
row = []
for field_index in field_indexes:
src_bid, offset = src_fields_loc[field_index]
if isinstance(blocks[src_bid], torch.Tensor):
row.append(blocks[src_bid][lid][offset].item())
else:
row.append(blocks[src_bid][lid][offset])
for _field_index in _field_indexes:
_src_bid, _offset = align_fields_loc[_field_index]
row.append(blocks[_src_bid][lid][offset].item() if isinstance(blocks[_src_bid], torch.Tensor)
else blocks[_src_bid][lid][offset])

block_buf.append(row)
_align_block.append(row)

ret_blocks.append(dst_dm.blocks[dst_bid].convert_block(block_buf))
ret_blocks.append(dst_dm.blocks[dst_bid].convert_block(_align_block))

return ret_blocks

l_df = data_frames[0]
data_manager = l_df.data_manager
data_manager = l_df.data_manager.duplicate()
l_fields_loc = data_manager.get_fields_loc()
l_field_names = data_manager.get_field_name_list()
l_field_types = [data_manager.get_block(_bid).block_type for _bid, _ in l_fields_loc]
l_block_table = l_df.block_table
type_change = False
for r_df in data_frames[1:]:
if set(l_df.schema.columns) != set(r_df.schema.columns):
raise ValueError("vstack of dataframes should have same schemas")

for idx, field_name in enumerate(l_field_names):
block_type = r_df.data_manager.get_block(
r_df.data_manager.loc_block(field_name, with_offset=False)).block_type
if block_type > l_field_types[idx]:
l_field_types[idx] = block_type
type_change = True

if type_change:
changed_fields, changed_block_types, changed_fields_loc = [], [], []
changed_block_types = []
for idx in range(len(l_field_names)):
field_name, block_type, (bid, offset) = l_field_names[idx], l_field_types[idx], l_fields_loc[idx]
if block_type != data_manager.get_block(bid).block_type:
changed_fields.append(field_name)
changed_block_types.append(block_type)
changed_fields_loc.append((bid, offset))

narrow_blocks, dst_blocks = data_manager.split_columns(changed_fields, changed_block_types)
l_block_table = promote_partial_block_types(l_block_table, narrow_blocks=narrow_blocks, dst_blocks=dst_blocks,
data_manager=data_manager, dst_fields_loc=changed_fields_loc)

l_flatten_func = functools.partial(_flatten_partition, block_num=data_manager.block_num)
l_flatten = l_df.block_table.mapPartitions(l_flatten_func, use_previous_behavior=False)
l_flatten = l_block_table.mapPartitions(l_flatten_func, use_previous_behavior=False)

for r_df in data_frames[1:]:
if l_df.schema != r_df.schema:
raise ValueError("Vstack of two dataframe with different schemas")

r_field_names = r_df.data_manager.get_field_name_list()
r_fields_loc = r_df.data_manager.get_fields_loc()
block_table = r_df.block_table
if l_fields_loc != r_fields_loc:
_align_func = functools.partial(_align_blocks, src_fields_loc=r_fields_loc, dm=data_manager)
block_table = block_table.mapValues(_align_func)
r_field_types = [data_manager.get_block(_bid).block_type for _bid, _ in r_fields_loc]
r_type_change = False if l_field_types != r_field_types else True
r_block_table = r_df.block_table
if l_field_names != r_field_names or r_type_change:
shuffle_r_fields_loc, full_migrate_set = [() for _ in range(len(r_field_names))], set()
for field_name, loc in zip(r_field_names, r_fields_loc):
l_offset = data_manager.get_field_offset(field_name)
shuffle_r_fields_loc[l_offset] = loc

for bid in range(r_df.data_manager.block_num):
r_field_indexes = r_df.data_manager.get_block(bid).field_indexes
field_indexes = [data_manager.get_field_offset(r_field_names[idx]) for idx in r_field_indexes]
l_bid = data_manager.loc_block(r_field_names[r_field_indexes[0]], with_offset=False)
if field_indexes == data_manager.get_block(l_bid).field_indexes:
full_migrate_set.add(bid)

_align_func = functools.partial(_align_blocks, align_fields_loc=shuffle_r_fields_loc,
full_block_migrate_set=full_migrate_set, dst_dm=data_manager)
r_block_table = r_block_table.mapValues(_align_func)

r_flatten_func = functools.partial(_flatten_partition, block_num=data_manager.block_num)
r_flatten = block_table.mapPartitions(r_flatten_func, use_previous_behavior=False)
r_flatten = r_block_table.mapPartitions(r_flatten_func, use_previous_behavior=False)
l_flatten = l_flatten.union(r_flatten)

# TODO: data-manager support align blocks first
# TODO: a fast way of vstack is just increase partition_id in r_df, then union,
# but data in every partition may be unbalance, so we use a more slow way by flatten data first

partition_order_mappings = get_partition_order_by_raw_table(l_flatten)
_convert_to_block_func = functools.partial(to_blocks,
dm=data_manager,
partition_mappings=partition_order_mappings)
_convert_to_block_func = functools.partial(to_blocks, dm=data_manager, partition_mappings=partition_order_mappings)
block_table = l_flatten.mapPartitions(_convert_to_block_func, use_previous_behavior=False)
block_table, data_manager = compress_blocks(block_table, data_manager)

block_table = l_flatten.mapPartitions(_convert_to_block_func,
use_previous_behavior=False)
return DataFrame(
l_df._ctx,
block_table,
Expand Down Expand Up @@ -256,7 +285,7 @@ def _flatten_partition(kvs, block_num=0):
return _flattens


def to_blocks(kvs, dm: DataManager=None, partition_mappings: dict=None):
def to_blocks(kvs, dm: DataManager = None, partition_mappings: dict = None):
ret_blocks = [[] for i in range(dm.block_num)]

partition_id = None
Expand Down
34 changes: 34 additions & 0 deletions python/fate/arch/dataframe/ops/_promote_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import torch
from ..manager import DataManager
from ..manager.block_manager import Block
from typing import List, Tuple


def promote_types(block_table, data_manager: DataManager, to_promote_blocks):
Expand All @@ -30,3 +33,34 @@ def promote_types(block_table, data_manager: DataManager, to_promote_blocks):

return block_table, data_manager


def promote_partial_block_types(block_table, narrow_blocks, dst_blocks, dst_fields_loc,
data_manager: DataManager, inplace=True):
def _mapper(blocks, narrow_loc: list = None, dst_bids: list = None,
dst_loc: List[Tuple[str, str]] = None, dm: DataManager = None, inp: bool = True):
ret_blocks = []
for block in blocks:
if inp:
if isinstance(block, torch.Tensor):
ret_blocks.append(block.clone())
else:
ret_blocks.append(block.copy())
else:
ret_blocks.append(block)

for i in range(len(ret_blocks), dm.block_num):
ret_blocks.append([])

for bid, offsets in narrow_loc:
ret_blocks[bid] = ret_blocks[bid][:, offsets]

for dst_bid, (src_bid, src_offset) in zip(dst_bids, dst_loc):
block_values = blocks[src_bid][:, [src_offset]]
ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block(block_values)

return ret_blocks

_mapper_func = functools.partial(_mapper, narrow_loc=narrow_blocks, dst_bids=dst_blocks,
dst_loc=dst_fields_loc, dm=data_manager, inp=inplace)

return block_table.mapValues(_mapper_func)

0 comments on commit a7d6e50

Please sign in to comment.