
Commit

Merge branch 'dev-2.0.0-beta' of https://github.com/FederatedAI/FATE into feature-2.0.0-beta-fate-test
nemirorox committed Sep 7, 2023
2 parents 3d4b07c + 5a43560 commit 016d1c5
Showing 4 changed files with 29 additions and 6 deletions.
python/fate/arch/dataframe/manager/block_manager.py (27 changes: 25 additions & 2 deletions)
@@ -294,10 +294,13 @@ def __init__(self, *args, **kwargs):
    def convert_block(block):
        if isinstance(block, torch.Tensor):
            if block.dtype == torch.int32:
            return block.clone().detach()
                return block
            else:
                return block.to(torch.int32)
        return torch.tensor(np.array(block, dtype="int32"), dtype=torch.int32)
        try:
            return torch.tensor(block, dtype=torch.int32)
        except ValueError:
            return torch.tensor(np.array(block, dtype="int32"), dtype=torch.int32)


class Int64Block(Block):
@@ -307,6 +310,11 @@ def __init__(self, *args, **kwargs):

    @staticmethod
    def convert_block(block):
        if isinstance(block, torch.Tensor):
            if block.dtype == torch.int64:
                return block
            else:
                return block.to(torch.int64)
        try:
            return torch.tensor(block, dtype=torch.int64)
        except ValueError:
@@ -320,6 +328,11 @@ def __init__(self, *args, **kwargs):

    @staticmethod
    def convert_block(block):
        if isinstance(block, torch.Tensor):
            if block.dtype == torch.float32:
                return block
            else:
                return block.to(torch.float32)
        try:
            return torch.tensor(block, dtype=torch.float32)
        except ValueError:
@@ -333,6 +346,11 @@ def __init__(self, *args, **kwargs):

    @staticmethod
    def convert_block(block):
        if isinstance(block, torch.Tensor):
            if block.dtype == torch.float64:
                return block
            else:
                return block.to(torch.float64)
        try:
            return torch.tensor(block, dtype=torch.float64)
        except ValueError:
@@ -346,6 +364,11 @@ def __init__(self, *args, **kwargs):

    @staticmethod
    def convert_block(block):
        if isinstance(block, torch.Tensor):
            if block.dtype == torch.bool:
                return block
            else:
                return block.to(torch.bool)
        try:
            return torch.tensor(block, dtype=torch.bool)
        except ValueError:
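
Across all five typed blocks, convert_block now follows the same pattern: a torch.Tensor input is returned as-is when its dtype already matches and cast with .to() otherwise, while list or array inputs go through torch.tensor with a numpy fallback when the direct conversion raises ValueError. A minimal standalone sketch of that pattern (the helper name and the generic dtype parameters are illustrative, not FATE's API):

import numpy as np
import torch

def to_typed_tensor(block, torch_dtype=torch.int32, np_dtype="int32"):
    # Tensors are passed through (or cast) instead of being rebuilt.
    if isinstance(block, torch.Tensor):
        return block if block.dtype == torch_dtype else block.to(torch_dtype)
    # Lists/arrays: try the direct conversion first, then fall back to an
    # explicit numpy cast when torch cannot infer it (ValueError).
    try:
        return torch.tensor(block, dtype=torch_dtype)
    except ValueError:
        return torch.tensor(np.array(block, dtype=np_dtype), dtype=torch_dtype)

# e.g. to_typed_tensor([[1, 2], [3, 4]], torch.int64, "int64")
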
python/fate/arch/dataframe/ops/_encoder.py (2 changes: 1 addition & 1 deletion)
@@ -221,7 +221,7 @@ def _mapper(
        if isinstance(blocks[src_bid], torch.Tensor):
            ret = torch.bucketize(blocks[src_bid][:, [src_offset]], boundary, out_int32=False)
        else:
            ret = torch.bucketize(blocks[src_bid][:, [src_offset]], boundary)
            ret = np.digitize(blocks[src_bid][:, [src_offset]], boundary)

        ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block(ret)

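
The change above affects only the non-tensor branch: blocks held as numpy arrays are now binned with np.digitize instead of torch.bucketize, which expects tensor inputs. Both calls map each value to the index of the bucket it falls into for ascending split points; a small illustrative comparison (values and split points are made up):

import numpy as np
import torch

boundary = [0.25, 0.5, 0.75]
col = np.array([[0.1], [0.3], [0.6], [0.9]])

np_bins = np.digitize(col, boundary)                     # numpy-array block
torch_bins = torch.bucketize(torch.tensor(col),          # tensor block
                             torch.tensor(boundary, dtype=torch.float64),
                             out_int32=False)

print(np_bins.ravel())     # [0 1 2 3]
print(torch_bins.ravel())  # tensor([0, 1, 2, 3]); values that hit a boundary
                           # exactly are handled slightly differently by the two
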
python/fate/arch/dataframe/utils/_sample.py (4 changes: 2 additions & 2 deletions)
@@ -115,7 +115,7 @@ def _sample_guest(
        regenerated_sample_id_prefix = generate_sample_id_prefix()
        choice_with_regenerated_ids = None
        for label, f in frac.items():
            label_df = df[(df.label == label).as_tensor()]
            label_df = df.iloc(df.label == label)
            label_n = max(1, int(label_df.shape[0] * f))
            choices = resample(list(range(label_df.shape[0])), replace=True,
                               n_samples=label_n, random_state=random_state)
@@ -139,7 +139,7 @@
    else:
        sample_df = None
        for label, f in frac.items():
            label_df = df[(df.label == label).as_tensor()]
            label_df = df.iloc(df.label == label)
            label_n = max(1, int(label_df.shape[0] * f))
            sample_label_df = label_df.sample(n=label_n, random_state=random_state)

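
Both loops in _sample_guest now take the per-label slice with df.iloc(df.label == label) instead of indexing with a boolean tensor. As a rough pandas analogue of the per-label logic visible in the two hunks (assuming the first loop is the with-replacement upsampling path and the second the plain downsampling path; a sketch, not FATE's DataFrame API):

import pandas as pd
from sklearn.utils import resample

def sample_per_label_frac(df, frac, upsample=False, random_state=None):
    parts = []
    for label, f in frac.items():
        label_df = df[df["label"] == label]
        label_n = max(1, int(label_df.shape[0] * f))   # at least one row per label
        if upsample:
            # draw row positions with replacement, as the resample-based loop does
            choices = resample(list(range(label_df.shape[0])), replace=True,
                               n_samples=label_n, random_state=random_state)
            parts.append(label_df.iloc[choices])
        else:
            # plain sampling without replacement, as the second loop does
            parts.append(label_df.sample(n=label_n, random_state=random_state))
    return pd.concat(parts)
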
python/fate/ml/model_selection/data_split.py (2 changes: 1 addition & 1 deletion)
@@ -167,7 +167,7 @@ def sample_per_label(train_data, sample_count=None, random_state=None):
    sampled_n = 0
    data_n = train_data.shape[0]
    for i, label in enumerate(labels):
        label_data = train_data[(train_data.label == int(label)).as_tensor()]
        label_data = train_data.iloc(train_data.label == int(label))
        if i == len(labels) - 1:
            # last label:
            to_sample_n = sample_count - sampled_n
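
In sample_per_label, each label's slice is likewise taken with train_data.iloc(train_data.label == int(label)); the visible tail of the loop shows the last label absorbing whatever remains of sample_count so the per-label counts sum up exactly. A small sketch of that allocation (the proportional share for the non-last labels is an assumption, since that part of the function sits below the fold):

def allocate_per_label(label_counts, sample_count):
    data_n = sum(label_counts.values())
    allocation, sampled_n = {}, 0
    labels = list(label_counts)
    for i, label in enumerate(labels):
        if i == len(labels) - 1:
            # last label: take the remainder so the totals match exactly
            to_sample_n = sample_count - sampled_n
        else:
            # assumed: share proportional to the label's frequency
            to_sample_n = round(sample_count * label_counts[label] / data_n)
        allocation[label] = to_sample_n
        sampled_n += to_sample_n
    return allocation

# e.g. allocate_per_label({0: 60, 1: 30, 2: 10}, 25) -> {0: 15, 1: 8, 2: 2}
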

0 comments on commit 016d1c5
