
Commit

Merge branch 'dev-2.0.0-beta' of https://github.com/FederatedAI/FATE into feature-2.0.0-beta-fate-test
nemirorox committed Sep 8, 2023
2 parents baae180 + 632440e commit 01c3741
Showing 15 changed files with 2,076 additions and 321 deletions.
python/fate/arch/context/_cipher.py (52 changes: 39 additions & 13 deletions)
@@ -57,36 +57,62 @@ def setup(self, options):

             sk, pk, coder = keygen(key_size)
             tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator)
-            return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher)

-            # if kind == "heu":
-            #     from fate.arch.protocol.phe.heu import evaluator, keygen
-            #     from fate.arch.tensor.phe import PHETensorCipher
-            #
-            #     sk, pk, coder = keygen(key_size)
-            #     tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator)
-            #     return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher)
-            # #
+            return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher, True, True, True)

+        if kind == "ou":
+            from fate.arch.protocol.phe.ou import evaluator, keygen
+            from fate.arch.tensor.phe import PHETensorCipher

+            sk, pk, coder = keygen(key_size)
+            tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator)
+            return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher, False, False, True)

         elif kind == "mock":
-            # from fate.arch.protocol.phe.mock import evaluator, keygen
             from fate.arch.protocol.phe.mock import evaluator, keygen
             from fate.arch.tensor.phe import PHETensorCipher

             sk, pk, coder = keygen(key_size)
             tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator)
-            return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher)
+            return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher, True, False, False)

         else:
             raise ValueError(f"Unknown PHE keygen kind: {self.kind}")


 class PHECipher:
-    def __init__(self, key_size, pk, sk, evaluator, coder, tensor_cipher) -> None:
+    def __init__(
+        self,
+        key_size,
+        pk,
+        sk,
+        evaluator,
+        coder,
+        tensor_cipher,
+        can_support_negative_number,
+        can_support_squeeze,
+        can_support_pack,
+    ) -> None:
         self._key_size = key_size
         self._pk = pk
         self._sk = sk
         self._coder = coder
         self._evaluator = evaluator
         self._tensor_cipher = tensor_cipher
+        self._can_support_negative_number = can_support_negative_number
+        self._can_support_squeeze = can_support_squeeze
+        self._can_support_pack = can_support_pack

+    @property
+    def can_support_negative_number(self):
+        return self._tensor_cipher.can_support_negative_number

+    @property
+    def can_support_squeeze(self):
+        return self._tensor_cipher.can_support_squeeze

+    @property
+    def can_support_pack(self):
+        return self._tensor_cipher.can_support_pack

     @property
     def key_size(self):
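For context on the diff above: the three new constructor arguments are capability flags (can_support_negative_number, can_support_squeeze, can_support_pack), set per scheme in setup() as True, True, True for the branch above the hunk, False, False, True for "ou", and True, False, False for "mock". Below is a minimal, self-contained sketch of how a caller might branch on such flags; ToyCipher and choose_aggregation are hypothetical stand-ins for illustration only, not FATE APIs.

# Hypothetical sketch only: ToyCipher and choose_aggregation are not FATE APIs.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyCipher:
    # Mirrors the three capability flags added to PHECipher in this commit.
    can_support_negative_number: bool
    can_support_squeeze: bool
    can_support_pack: bool


def choose_aggregation(cipher: ToyCipher, values: List[int]) -> str:
    # Pack several non-negative values into one ciphertext when the scheme
    # allows it; otherwise fall back to encrypting values one by one.
    if cipher.can_support_pack and all(v >= 0 for v in values):
        return "packed"
    if cipher.can_support_negative_number or all(v >= 0 for v in values):
        return "per-value"
    raise ValueError("scheme cannot represent negative values")


ou_like = ToyCipher(can_support_negative_number=False, can_support_squeeze=False, can_support_pack=True)
mock_like = ToyCipher(can_support_negative_number=True, can_support_squeeze=False, can_support_pack=False)
print(choose_aggregation(ou_like, [1, 2, 3]))    # packed
print(choose_aggregation(mock_like, [-1, 2]))    # per-value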
python/fate/arch/histogram/_histogram_sbt.py (9 changes: 6 additions & 3 deletions)
@@ -4,7 +4,7 @@

 class HistogramBuilder:
     def __init__(
-        self, num_node, feature_bin_sizes, value_schemas, global_seed=None, seed=None, node_mapping=None, k=None
+        self, num_node, feature_bin_sizes, value_schemas, global_seed=None, seed=None, node_mapping=None, k=None, enable_cumsum=True
     ):
         self._num_node = num_node
         self._feature_bin_sizes = feature_bin_sizes
@@ -13,6 +13,7 @@ def __init__(
         self._global_seed = global_seed
         self._seed = seed
         self._node_mapping = node_mapping
+        self._enable_cumsum = enable_cumsum
         self._k = k

     def __str__(self):
@@ -35,6 +36,7 @@ def statistic(self, data) -> "DistributedHistogram":
             self._global_seed,
             self._k,
             self._node_mapping,
+            self._enable_cumsum,
         )
         table = data.mapReducePartitions(mapper, lambda x, y: x.iadd(y))
         data = DistributedHistogram(
@@ -43,13 +45,14 @@ def statistic(self, data) -> "DistributedHistogram":
         return data


-def get_partition_hist_build_mapper(num_node, feature_bin_sizes, value_schemas, global_seed, k, node_mapping):
+def get_partition_hist_build_mapper(num_node, feature_bin_sizes, value_schemas, global_seed, k, node_mapping, enable_cumsum):
     def _partition_hist_build_mapper(part):
         hist = Histogram.create(num_node, feature_bin_sizes, value_schemas)
         for _, raw in part:
             feature_ids, node_ids, targets = raw
             hist.i_update(feature_ids, node_ids, targets, node_mapping)
-        hist.i_cumsum_bins()
+        if enable_cumsum:
+            hist.i_cumsum_bins()
         if global_seed is not None:
             hist.i_shuffle(global_seed)
         splits = hist.to_splits(k)
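For context on the enable_cumsum flag above: in histogram-based tree building, the cumulative sum over each feature's bins typically turns per-bin totals into left-of-split totals, and the new flag lets callers skip that step when only the raw bin totals are needed. A tiny self-contained sketch of the difference, using plain Python lists in place of the Histogram class (none of the names below are FATE APIs):

# Illustration only: the prefix-sum step that enable_cumsum=False would skip.
from itertools import accumulate

bin_counts = [3, 0, 5, 2]               # per-bin totals for one feature of one node
cumsum_bins = list(accumulate(bin_counts))

print(bin_counts)    # [3, 0, 5, 2]
print(cumsum_bins)   # [3, 3, 8, 10]  (samples at or to the left of each bin boundary)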
