Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev 2.0.0 beta fix mock #5116

Merged
merged 3 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions python/fate/arch/context/_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,6 @@ def setup(self, options):
kind = options.get("kind", self.kind)
key_size = options.get("key_length", 1024)

if kind == "paillier_old":
import fate_utils
from fate.arch.tensor.paillier import PaillierTensorCipher

pk, sk = fate_utils.tensor.keygen(key_size)
tensor_cipher = PaillierTensorCipher.from_raw_cipher(pk, None, sk)
return PHECipher(key_size, pk, sk, None, None, tensor_cipher)

if kind == "paillier":
from fate.arch.protocol.phe.paillier import evaluator, keygen
from fate.arch.tensor.phe import PHETensorCipher
Expand All @@ -67,20 +59,22 @@ def setup(self, options):
tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator)
return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher)

if kind == "heu":
from fate.arch.protocol.phe.heu import evaluator, keygen
# if kind == "heu":
# from fate.arch.protocol.phe.heu import evaluator, keygen
# from fate.arch.tensor.phe import PHETensorCipher
#
# sk, pk, coder = keygen(key_size)
# tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator)
# return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher)
# #
elif kind == "mock":
from fate.arch.protocol.phe.mock import evaluator, keygen
from fate.arch.tensor.phe import PHETensorCipher

sk, pk, coder = keygen(key_size)
tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator)
return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher)

elif kind == "mock":
from fate.arch.tensor.mock import PaillierTensorCipher

tensor_cipher = PaillierTensorCipher(**options)
return PHECipher(key_size, None, None, None, None, tensor_cipher)

else:
raise ValueError(f"Unknown PHE keygen kind: {self.kind}")

Expand All @@ -93,7 +87,7 @@ def __init__(self, key_size, pk, sk, evaluator, coder, tensor_cipher) -> None:
self._coder = coder
self._evaluator = evaluator
self._tensor_cipher = tensor_cipher

@property
def key_size(self):
return self._key_size
Expand Down
49 changes: 22 additions & 27 deletions python/fate/arch/dataframe/manager/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import pandas as pd
import torch
from fate.arch.tensor.phe._tensor import PHETensor
from fate_utils.paillier import CiphertextVector

from .schema_manager import SchemaManager

Expand Down Expand Up @@ -245,33 +244,30 @@ def convert_block_type(self, block_type):

return converted_block

@classmethod
def retrieval_row(cls, block, indexes):
if isinstance(block, CiphertextVector):
return block.slice_indexes(indexes)
elif isinstance(block, pd.Index):
if isinstance(indexes, list):
return block[indexes]
else:
return pd.Index(block[indexes])
else:
return block[indexes]
# @classmethod
# def retrieval_row(cls, block, indexes):
# if isinstance(block, CiphertextVector):
# return block.slice_indexes(indexes)
# elif isinstance(block, pd.Index):
# if isinstance(indexes, list):
# return block[indexes]
# else:
# return pd.Index(block[indexes])
# else:
# return block[indexes]

@classmethod
def transform_block_to_list(cls, block):
if isinstance(block, CiphertextVector):
return [block.slice_indexes([i]) for i in range(len(block))]
else:
return block.tolist()
return block.tolist()

@classmethod
def transform_row_to_raw(cls, block, index):
if isinstance(block, pd.Index):
return block[index]
elif isinstance(block, CiphertextVector):
return block.slice_indexes([index])
else:
return block[index].tolist()
# @classmethod
# def transform_row_to_raw(cls, block, index):
# if isinstance(block, pd.Index):
# return block[index]
# elif isinstance(block, CiphertextVector):
# return block.slice_indexes([index])
# else:
# return block[index].tolist()

@classmethod
def vstack(cls, blocks):
Expand Down Expand Up @@ -385,10 +381,9 @@ def set_extra_kwargs(self, pk, evaluator, coder, dtype, device):
self._dtype = dtype
self._device = device

@staticmethod
def convert_block(block):
def convert_block(self, block):
if isinstance(block, list):
block = block[0].cat(block[1:])
block = self._evaluator.cat(block)
return block

def convert_to_phe_tensor(self, block, shape):
Expand Down
Loading