Skip to content

Commit

Permalink
fix dataframe convert when input cols are str
Browse files Browse the repository at this point in the history
Signed-off-by: mgqa34 <mgq3374541@163.com>
  • Loading branch information
mgqa34 committed Jun 20, 2023
1 parent 3b5923b commit b115c34
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
10 changes: 5 additions & 5 deletions python/fate/arch/dataframe/_frame_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ def _dense_format_to_frame(self, ctx, table):
from .ops._indexer import get_partition_order_by_raw_table
partition_order_mappings = get_partition_order_by_raw_table(table)
# partition_order_mappings = _get_partition_order(table)
functools.partial(_to_blocks,
data_manager=data_manager,
retrieval_index_dict=retrieval_index_dict,
partition_order_mappings=partition_order_mappings)
to_block_func = functools.partial(_to_blocks,
data_manager=data_manager,
retrieval_index_dict=retrieval_index_dict,
partition_order_mappings=partition_order_mappings)
block_table = table.mapPartitions(
_to_blocks,
to_block_func,
use_previous_behavior=False
)

Expand Down
25 changes: 20 additions & 5 deletions python/fate/arch/dataframe/manager/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,10 @@ def __init__(self, *args, **kwargs):

@staticmethod
def convert_block(block):
return torch.tensor(block, dtype=torch.int32)
try:
return torch.tensor(block, dtype=torch.int32)
except ValueError:
return torch.tensor(np.array(block, dtype="int32"), dtype=torch.int32)


class Int64Block(Block):
Expand All @@ -230,7 +233,10 @@ def __init__(self, *args, **kwargs):

@staticmethod
def convert_block(block):
return torch.tensor(block, dtype=torch.int64)
try:
return torch.tensor(block, dtype=torch.int64)
except ValueError:
return torch.tensor(np.array(block, dtype="int64"), dtype=torch.int64)


class Float32Block(Block):
Expand All @@ -240,7 +246,10 @@ def __init__(self, *args, **kwargs):

@staticmethod
def convert_block(block):
return torch.tensor(block, dtype=torch.float32)
try:
return torch.tensor(block, dtype=torch.float32)
except ValueError:
return torch.tensor(np.array(block, dtype="float32"), dtype=torch.float32)


class Float64Block(Block):
Expand All @@ -250,7 +259,10 @@ def __init__(self, *args, **kwargs):

@staticmethod
def convert_block(block):
return torch.tensor(block, dtype=torch.float64)
try:
return torch.tensor(block, dtype=torch.float64)
except ValueError:
return torch.tensor(np.array(block, dtype="float64"), dtype=torch.float64)


class BoolBlock(Block):
Expand All @@ -260,7 +272,10 @@ def __init__(self, *args, **kwargs):

@staticmethod
def convert_block(block):
return torch.tensor(block, dtype=torch.bool)
try:
return torch.tensor(block, dtype=torch.bool)
except ValueError:
return torch.tensor(np.array(block, dtype="bool"), dtype=torch.bool)


class IndexBlock(Block):
Expand Down

0 comments on commit b115c34

Please sign in to comment.