Patch summary (review notes; everything ahead of the first "diff --git"
header is commentary and is not part of any hunk).

1. arch/dataframe/_frame_reader.py, _dense_format_to_frame: the result of
   functools.partial(_to_blocks, ...) used to be discarded and the bare
   _to_blocks was handed to mapPartitions. The patch binds the partial to
   to_block_func and passes that instead. It also adds a mapValues step that
   splits every value on self._delimiter before the partitions are mapped.
2. arch/dataframe/manager/block_manager.py: the five convert_block static
   methods shown (int32, int64, float32, float64, bool) now retry through
   np.array(block, dtype=...) when torch.tensor(block, ...) raises ValueError.
   NOTE(review): assumes `np` is already imported in block_manager.py; the
   import is not visible in this patch - confirm.
3. components/components/dataframe_transformer.py: drops the typing import,
   changes the defaults of the "namespace" and "name" parameters to None, and
   writes the result with dataframe_output.write(ctx, df, name=..., namespace=...)
   instead of ctx.writer(dataframe_output).write(...).

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Line structure below is consistent with every
hunk header's line counts, but leading whitespace inside the hunks was
reconstructed from Python syntax. Verify against the target tree, or apply
with `git apply --ignore-whitespace`.

diff --git a/python/fate/arch/dataframe/_frame_reader.py b/python/fate/arch/dataframe/_frame_reader.py
index 26baf8a928..cec673cc88 100644
--- a/python/fate/arch/dataframe/_frame_reader.py
+++ b/python/fate/arch/dataframe/_frame_reader.py
@@ -86,12 +86,13 @@ def _dense_format_to_frame(self, ctx, table):
         from .ops._indexer import get_partition_order_by_raw_table
         partition_order_mappings = get_partition_order_by_raw_table(table)
         # partition_order_mappings = _get_partition_order(table)
-        functools.partial(_to_blocks,
-                          data_manager=data_manager,
-                          retrieval_index_dict=retrieval_index_dict,
-                          partition_order_mappings=partition_order_mappings)
+        table = table.mapValues(lambda value: value.split(self._delimiter, -1))
+        to_block_func = functools.partial(_to_blocks,
+                                          data_manager=data_manager,
+                                          retrieval_index_dict=retrieval_index_dict,
+                                          partition_order_mappings=partition_order_mappings)
         block_table = table.mapPartitions(
-            _to_blocks,
+            to_block_func,
             use_previous_behavior=False
         )
 
diff --git a/python/fate/arch/dataframe/manager/block_manager.py b/python/fate/arch/dataframe/manager/block_manager.py
index 763c42df37..50a96bf7a6 100644
--- a/python/fate/arch/dataframe/manager/block_manager.py
+++ b/python/fate/arch/dataframe/manager/block_manager.py
@@ -220,7 +220,10 @@ def __init__(self, *args, **kwargs):
 
     @staticmethod
     def convert_block(block):
-        return torch.tensor(block, dtype=torch.int32)
+        try:
+            return torch.tensor(block, dtype=torch.int32)
+        except ValueError:
+            return torch.tensor(np.array(block, dtype="int32"), dtype=torch.int32)
 
 
 class Int64Block(Block):
@@ -230,7 +233,10 @@ def __init__(self, *args, **kwargs):
 
     @staticmethod
     def convert_block(block):
-        return torch.tensor(block, dtype=torch.int64)
+        try:
+            return torch.tensor(block, dtype=torch.int64)
+        except ValueError:
+            return torch.tensor(np.array(block, dtype="int64"), dtype=torch.int64)
 
 
 class Float32Block(Block):
@@ -240,7 +246,10 @@ def __init__(self, *args, **kwargs):
 
     @staticmethod
     def convert_block(block):
-        return torch.tensor(block, dtype=torch.float32)
+        try:
+            return torch.tensor(block, dtype=torch.float32)
+        except ValueError:
+            return torch.tensor(np.array(block, dtype="float32"), dtype=torch.float32)
 
 
 class Float64Block(Block):
@@ -250,7 +259,10 @@ def __init__(self, *args, **kwargs):
 
     @staticmethod
     def convert_block(block):
-        return torch.tensor(block, dtype=torch.float64)
+        try:
+            return torch.tensor(block, dtype=torch.float64)
+        except ValueError:
+            return torch.tensor(np.array(block, dtype="float64"), dtype=torch.float64)
 
 
 class BoolBlock(Block):
@@ -260,7 +272,10 @@ def __init__(self, *args, **kwargs):
 
     @staticmethod
     def convert_block(block):
-        return torch.tensor(block, dtype=torch.bool)
+        try:
+            return torch.tensor(block, dtype=torch.bool)
+        except ValueError:
+            return torch.tensor(np.array(block, dtype="bool"), dtype=torch.bool)
 
 
 class IndexBlock(Block):
diff --git a/python/fate/components/components/dataframe_transformer.py b/python/fate/components/components/dataframe_transformer.py
index 5b73f31ee6..01b38e32ec 100644
--- a/python/fate/components/components/dataframe_transformer.py
+++ b/python/fate/components/components/dataframe_transformer.py
@@ -14,14 +14,13 @@
 # limitations under the License.
 #
 from fate.components.core import LOCAL, Role, cpn
-from typing import Union, List, Dict
 
 
 @cpn.component(roles=[LOCAL])
 @cpn.table_input("table", roles=[LOCAL])
 @cpn.dataframe_output("dataframe_output", roles=[LOCAL])
-@cpn.parameter("namespace", type=str, default=",", optional=True)
-@cpn.parameter("name", type=str, default="dense", optional=True)
+@cpn.parameter("namespace", type=str, default=None, optional=True)
+@cpn.parameter("name", type=str, default=None, optional=True)
 @cpn.parameter("anonymous_role", type=str, default=None, optional=True)
 @cpn.parameter("anonymous_party_id", type=str, default=None, optional=True)
 def dataframe_transformer(
@@ -35,6 +34,7 @@ def dataframe_transformer(
     anonymous_party_id,
 ):
     from fate.arch.dataframe import TableReader
+
     metadata = table.schema
     table_reader = TableReader(
         sample_id_name=metadata.get("sample_id_name", None),
@@ -57,4 +57,4 @@ def dataframe_transformer(
     )
 
     df = table_reader.to_frame(ctx, table)
-    ctx.writer(dataframe_output).write(df, namespace=namespace, name=name)
+    dataframe_output.write(ctx, df, name=name, namespace=namespace)