diff --git a/python/fate/arch/context/_cipher.py b/python/fate/arch/context/_cipher.py index c51807b3f9..d85006e7fc 100644 --- a/python/fate/arch/context/_cipher.py +++ b/python/fate/arch/context/_cipher.py @@ -57,36 +57,62 @@ def setup(self, options): sk, pk, coder = keygen(key_size) tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator) - return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher) - - # if kind == "heu": - # from fate.arch.protocol.phe.heu import evaluator, keygen - # from fate.arch.tensor.phe import PHETensorCipher - # - # sk, pk, coder = keygen(key_size) - # tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator) - # return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher) - # # + return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher, True, True, True) + + if kind == "ou": + from fate.arch.protocol.phe.ou import evaluator, keygen + from fate.arch.tensor.phe import PHETensorCipher + + sk, pk, coder = keygen(key_size) + tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator) + return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher, False, False, True) + elif kind == "mock": - # from fate.arch.protocol.phe.mock import evaluator, keygen + from fate.arch.protocol.phe.mock import evaluator, keygen from fate.arch.tensor.phe import PHETensorCipher sk, pk, coder = keygen(key_size) tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator) - return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher) + return PHECipher(key_size, pk, sk, evaluator, coder, tensor_cipher, True, False, False) else: raise ValueError(f"Unknown PHE keygen kind: {self.kind}") class PHECipher: - def __init__(self, key_size, pk, sk, evaluator, coder, tensor_cipher) -> None: + def __init__( + self, + key_size, + pk, + sk, + evaluator, + coder, + tensor_cipher, + can_support_negative_number, + can_support_squeeze, + can_support_pack, + ) -> None: self._key_size = key_size self._pk = pk self._sk = sk self._coder = coder self._evaluator = evaluator self._tensor_cipher = tensor_cipher + self._can_support_negative_number = can_support_negative_number + self._can_support_squeeze = can_support_squeeze + self._can_support_pack = can_support_pack + + @property + def can_support_negative_number(self): + return self._tensor_cipher.can_support_negative_number + + @property + def can_support_squeeze(self): + return self._tensor_cipher.can_support_squeeze + + @property + def can_support_pack(self): + return self._tensor_cipher.can_support_pack @property def key_size(self): diff --git a/python/fate/arch/histogram/_histogram_sbt.py b/python/fate/arch/histogram/_histogram_sbt.py index e45c2f6a49..24b12a9686 100644 --- a/python/fate/arch/histogram/_histogram_sbt.py +++ b/python/fate/arch/histogram/_histogram_sbt.py @@ -4,7 +4,7 @@ class HistogramBuilder: def __init__( - self, num_node, feature_bin_sizes, value_schemas, global_seed=None, seed=None, node_mapping=None, k=None + self, num_node, feature_bin_sizes, value_schemas, global_seed=None, seed=None, node_mapping=None, k=None, enable_cumsum=True ): self._num_node = num_node self._feature_bin_sizes = feature_bin_sizes @@ -13,6 +13,7 @@ def __init__( self._global_seed = global_seed self._seed = seed self._node_mapping = node_mapping + self._enable_cumsum = enable_cumsum self._k = k def __str__(self): @@ -35,6 +36,7 @@ def statistic(self, data) -> "DistributedHistogram": self._global_seed, self._k, self._node_mapping, + self._enable_cumsum, ) table = data.mapReducePartitions(mapper, lambda x, y: x.iadd(y)) data = DistributedHistogram( @@ -43,13 +45,14 @@ def statistic(self, data) -> "DistributedHistogram": return data -def get_partition_hist_build_mapper(num_node, feature_bin_sizes, value_schemas, global_seed, k, node_mapping): +def get_partition_hist_build_mapper(num_node, feature_bin_sizes, value_schemas, global_seed, k, node_mapping, enable_cumsum): def _partition_hist_build_mapper(part): hist = Histogram.create(num_node, feature_bin_sizes, value_schemas) for _, raw in part: feature_ids, node_ids, targets = raw hist.i_update(feature_ids, node_ids, targets, node_mapping) - hist.i_cumsum_bins() + if enable_cumsum: + hist.i_cumsum_bins() if global_seed is not None: hist.i_shuffle(global_seed) splits = hist.to_splits(k) diff --git a/python/fate/arch/protocol/phe/heu.py b/python/fate/arch/protocol/phe/heu.py deleted file mode 100644 index 1e3bfeb950..0000000000 --- a/python/fate/arch/protocol/phe/heu.py +++ /dev/null @@ -1,303 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy as np -import torch -from heu import numpy as hnp -from heu import phe - -V = torch.Tensor -EV = "FixedpointPaillierVector" -FV = "FixedpointVector" - - -class SK: - def __init__(self, sk: hnp.HeKit) -> None: - self.sk = sk.decryptor() - - def decrypt_to_encoded(self, vec: EV) -> FV: - return self.sk.decrypt(vec) - - -class PK: - def __init__(self, kit: hnp.HeKit): - self.kit = kit - self.encryptor = kit.encryptor() - - def encrypt_encoded(self, vec: FV, obfuscate: bool) -> EV: - return self.encryptor.encrypt(vec) - - def encrypt_encoded_scalar(self, val, obfuscate) -> EV: - return self.encryptor.encrypt(val) - - -class Coder: - def __init__(self, kit: hnp.HeKit): - self.kit = kit - self.float_encoder = kit.float_encoder() - self.int_encoder = kit.integer_encoder() - - def encode_tensor(self, tensor: V, dtype: torch.dtype = None) -> FV: - if dtype is None: - dtype = tensor.dtype - if dtype == torch.float64: - return self.kit.array(tensor.detach().numpy(), self.float_encoder) - if dtype == torch.float32: - return self.kit.array(tensor.detach().numpy(), self.float_encoder) - if dtype == torch.int64: - return self.kit.array(tensor.detach().numpy(), self.int_encoder) - if dtype == torch.int32: - return self.kit.array(tensor.detach().numpy(), self.int_encoder) - raise NotImplementedError(f"{dtype} not supported") - - def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None) -> V: - if dtype == torch.float64: - data = torch.tensor(tensor.to_numpy(self.float_encoder)).type(dtype) - elif dtype == torch.float32: - data = torch.tensor(tensor.to_numpy(self.float_encoder)).type(dtype) - elif dtype == torch.int64: - data = torch.tensor(tensor.to_numpy(self.int_encoder)).type(dtype) - elif dtype == torch.int32: - data = torch.tensor(tensor.to_numpy(self.int_encoder)).type(dtype) - else: - raise NotImplementedError(f"{dtype} not supported") - if shape is not None: - data = data.reshape(shape) - return data - - def encode_vec(self, vec: V, dtype: torch.dtype = None) -> FV: - if dtype is None: - dtype = vec.dtype - if dtype == torch.float64: - return self.encode_f64_vec(vec) - if dtype == torch.float32: - return self.encode_f32_vec(vec) - if dtype == torch.int64: - return self.encode_i64_vec(vec) - if dtype == torch.int32: - return self.encode_i32_vec(vec) - raise NotImplementedError(f"{vec.dtype} not supported") - - def decode_vec(self, vec: FV, dtype: torch.dtype) -> V: - if dtype == torch.float64: - return self.decode_f64_vec(vec) - if dtype == torch.float32: - return self.decode_f32_vec(vec) - if dtype == torch.int64: - return self.decode_i64_vec(vec) - if dtype == torch.int32: - return self.decode_i32_vec(vec) - raise NotImplementedError(f"{dtype} not supported") - - def encode(self, val, dtype=None) -> FV: - if isinstance(val, torch.Tensor): - assert val.ndim == 0, "only scalar supported" - if dtype is None: - dtype = val.dtype - val = val.item() - if dtype == torch.float64: - return self.encode_f64(val) - if dtype == torch.float32: - return self.encode_f32(val) - if dtype == torch.int64: - return self.encode_i64(val) - if dtype == torch.int32: - return self.encode_i32(val) - raise NotImplementedError(f"{dtype} not supported") - - def encode_f64(self, val: float): - return self.kit.array(val, self.float_encoder) - - def decode_f64(self, val): - return torch.tensor(val.to_numpy(self.float_encoder)).type(torch.float64) - - def encode_i64(self, val: int): - return self.kit.array(val, self.float_encoder) - - def decode_i64(self, val): - return torch.tensor(val.to_numpy(self.float_encoder)).type(torch.int64) - - def encode_f32(self, val: float): - return self.kit.array(val, self.float_encoder) - - def decode_f32(self, val): - return torch.tensor(val.to_numpy(self.float_encoder)).type(torch.float32) - - def encode_i32(self, val: int): - return self.kit.array(val, self.int_encoder) - - def decode_i32(self, val): - return torch.tensor(val.to_numpy(self.int_encoder)).type(torch.int32) - - def encode_f64_vec(self, vec: torch.Tensor): - return self.kit.array(vec.detach().numpy(), self.float_encoder) - - def decode_f64_vec(self, vec): - return torch.tensor(vec.to_numpy(self.float_encoder)).type(torch.float64) - - def encode_i64_vec(self, vec: torch.Tensor): - return self.kit.array(vec.detach().numpy(), self.int_encoder) - - def decode_i64_vec(self, vec): - return torch.tensor(vec.to_numpy(self.int_encoder)).type(torch.int64) - - def encode_f32_vec(self, vec: torch.Tensor): - return self.kit.array(vec.detach().numpy(), self.float_encoder) - - def decode_f32_vec(self, vec): - return torch.tensor(vec.to_numpy(self.float_encoder)).type(torch.float32) - - def encode_i32_vec(self, vec: torch.Tensor): - return self.kit.array(vec.detach().numpy(), self.int_encoder) - - def decode_i32_vec(self, vec): - return torch.tensor(vec.to_numpy(self.int_encoder)).type(torch.int32) - - -def keygen(key_size): - phe_kit = phe.setup(phe.SchemaType.ZPaillier, key_size) - kit = hnp.HeKit(phe_kit) - pub_kit = hnp.setup(kit.public_key()) - return SK(kit), PK(pub_kit), Coder(pub_kit) - - -class evaluator: - @staticmethod - def add(a: EV, b: EV, pk: PK): - return pk.kit.evaluator().add(a, b) - - @staticmethod - def add_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): - if output_dtype is None: - output_dtype = b.dtype - encoded = coder.encode_vec(b, dtype=output_dtype) - encrypted = pk.encrypt_encoded(encoded, obfuscate=False) - return pk.kit.evaluator().add(a, encrypted) - - @staticmethod - def add_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): - encoded = coder.encode(b, dtype=output_dtype) - encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) - return pk.kit.evaluator().add(a, encrypted) - - @staticmethod - def sub(a: EV, b: EV, pk: PK): - return pk.kit.evaluator().sub(a, b) - - @staticmethod - def sub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): - if output_dtype is None: - output_dtype = b.dtype - encoded = coder.encode_vec(b, dtype=output_dtype) - encrypted = pk.encrypt_encoded(encoded, obfuscate=False) - return pk.kit.evaluator().sub(a, encrypted) - - @staticmethod - def sub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): - encoded = coder.encode(b, dtype=output_dtype) - encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) - return pk.kit.evaluator().sub(a, encrypted) - - @staticmethod - def rsub(a: EV, b: EV, pk: PK): - return evaluator.sub(b, a, pk) - - @staticmethod - def rsub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): - if output_dtype is None: - output_dtype = b.dtype - encoded = coder.encode_vec(b, dtype=output_dtype) - encrypted = pk.encrypt_encoded(encoded, obfuscate=False) - return evaluator.rsub(a, encrypted, pk) - - @staticmethod - def rsub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): - encoded = coder.encode(b, dtype=output_dtype) - encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) - return evaluator.rsub(a, encrypted, pk) - - @staticmethod - def mul_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): - if output_dtype is None: - output_dtype = b.dtype - encoded = coder.encode_vec(b, dtype=output_dtype) - return pk.kit.evaluator().mul(a, encoded) - - @staticmethod - def mul_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): - encoded = coder.encode(b, dtype=output_dtype) - return pk.kit.evaluator().mul(a, encoded) - - @staticmethod - def matmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): - encoded = coder.encode_vec(b.reshape(b_shape), dtype=output_dtype) - # TODO: move this to python side so other protocols can use it without matmul support? - return pk.kit.evaluator().matmul(a, encoded) - - @staticmethod - def rmatmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): - encoded = coder.encode_vec(b, dtype=output_dtype) - return pk.kit.evaluator().matmul(encoded, a) - - @staticmethod - def zeros(pk: PK, size) -> EV: - return pk.encryptor.encrypt(pk.kit.array(np.zeros(size, np.int), pk.kit.integer_encoder())) - - @staticmethod - def i_add(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: - """ - inplace add, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) - Args: - pk: the public key - a: the vector to add to - b: the vector to add - sa: the start index of a - sb: the start index of b - size: the size to add - """ - if a is b: - a.iadd_vec_self(sa, sb, size, pk.pk) - else: - a.iadd_vec(b, sa, sb, size, pk.pk) - - @staticmethod - def slice(a: EV, start: int, size: int) -> EV: - """ - slice a[start:start+size] - Args: - a: the vector to slice - start: the start index - size: the size to slice - - Returns: - the sliced vector - """ - return a.slice(start, size) - - @staticmethod - def intervals_sum_with_step(pk: PK, a: EV, intervals: List[Tuple[int, int]], step: int): - """ - sum in the given intervals, with step size - - for example: - if step=2, intervals=[(0, 4), (6, 12)], a = [a0, a1, a2, a3, a4, a5, a6, a7,...] - then the result is [a0+a2, a1+a3, a6+a8+a10, a7+a9+a11] - """ - return a.intervals_sum_with_step(pk.pk, intervals, step) - - @staticmethod - def chunking_cumsum_with_step(pk: PK, a: EV, chunk_sizes: List[int], step: int): - """ - chunking cumsum with step size - - for example: - if step=2, chunk_sizes=[4, 2, 6], a = [a0, a1, a2, a3, a4, a5, a6, a7,...a11] - then the result is [a0, a1, a0+a2, a1+a3, a4, a5, a6, a7, a6+a8, a7+a9, a6+a8+a10, a7+a9+a11] - Args: - pk: the public key - a: the vector to cumsum - chunk_sizes: the chunk sizes, must sum to a.size - step: the step size, cumsum with skip step-1 elements - Returns: - the cumsum result - """ - return a.chunking_cumsum_with_step(pk.pk, chunk_sizes, step) diff --git a/python/fate/arch/protocol/phe/mock.py b/python/fate/arch/protocol/phe/mock.py index 625bf7543b..c716f66b8b 100644 --- a/python/fate/arch/protocol/phe/mock.py +++ b/python/fate/arch/protocol/phe/mock.py @@ -236,14 +236,28 @@ def mul_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): @staticmethod def matmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): - data = torch.matmul(a.data.reshape(a_shape), b.data.reshape(b_shape)).flatten() + left = a.data.reshape(a_shape) + right = b.data.reshape(b_shape) + target_type = torch.promote_types(a.data.dtype, b.data.dtype) + if left.dtype != target_type: + left = left.to(dtype=target_type) + if right.dtype != target_type: + right = right.to(dtype=target_type) + data = torch.matmul(left, right).flatten() if output_dtype is not None: data = data.to(dtype=output_dtype) return EV(data) @staticmethod def rmatmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): - data = torch.matmul(b.data.reshape(b_shape), a.data.reshape(a_shape)).flatten() + right = a.data.reshape(a_shape) + left = b.data.reshape(b_shape) + target_type = torch.promote_types(a.data.dtype, b.data.dtype) + if left.dtype != target_type: + left = left.to(dtype=target_type) + if right.dtype != target_type: + right = right.to(dtype=target_type) + data = torch.matmul(left, right).flatten() if output_dtype is not None: data = data.to(dtype=output_dtype) return EV(data) diff --git a/python/fate/arch/protocol/phe/ou.py b/python/fate/arch/protocol/phe/ou.py new file mode 100644 index 0000000000..81e0fc70e5 --- /dev/null +++ b/python/fate/arch/protocol/phe/ou.py @@ -0,0 +1,415 @@ +from typing import List, Optional, Tuple + +import torch +from fate_utils.ou import PK as _PK +from fate_utils.ou import SK as _SK +from fate_utils.ou import Coder as _Coder +from fate_utils.ou import Evaluator as _Evaluator +from fate_utils.ou import CiphertextVector, PlaintextVector +from fate_utils.ou import keygen as _keygen + +from .type import TensorEvaluator + +V = torch.Tensor +EV = CiphertextVector +FV = PlaintextVector + + +class SK: + def __init__(self, sk: _SK): + self.sk = sk + + def decrypt_to_encoded(self, vec: EV) -> FV: + return self.sk.decrypt_to_encoded(vec) + + +class PK: + def __init__(self, pk: _PK): + self.pk = pk + + def encrypt_encoded(self, vec: FV, obfuscate: bool) -> EV: + return self.pk.encrypt_encoded(vec, obfuscate) + + def encrypt_encoded_scalar(self, val, obfuscate) -> EV: + return self.pk.encrypt_encoded_scalar(val, obfuscate) + + +class Coder: + def __init__(self, coder: _Coder): + self.coder = coder + + def pack_floats(self, float_tensor: V, offset_bit: int, pack_num: int, precision: int) -> FV: + return self.coder.pack_floats(float_tensor.detach().tolist(), offset_bit, pack_num, precision) + + def unpack_floats(self, packed: FV, offset_bit: int, pack_num: int, precision: int, total_num: int) -> V: + return torch.tensor(self.coder.unpack_floats(packed, offset_bit, pack_num, precision, total_num)) + + def pack_vec(self, vec: torch.LongTensor, num_shift_bit, num_elem_each_pack) -> FV: + return self.coder.pack_u64_vec(vec.detach().tolist(), num_shift_bit, num_elem_each_pack) + + def unpack_vec(self, vec: FV, num_shift_bit, num_elem_each_pack, total_num) -> torch.LongTensor: + return torch.LongTensor(self.coder.unpack_u64_vec(vec, num_shift_bit, num_elem_each_pack, total_num)) + + def encode_tensor(self, tensor: V, dtype: torch.dtype = None) -> FV: + return self.encode_vec(tensor.flatten(), dtype=tensor.dtype) + + def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None, device=None) -> V: + data = self.decode_vec(tensor, dtype) + if shape is not None: + data = data.reshape(shape) + if device is not None: + data = data.to(device.to_torch_device()) + return data + + def encode_vec(self, vec: V, dtype: torch.dtype = None) -> FV: + if dtype is None: + dtype = vec.dtype + else: + if dtype != vec.dtype: + vec = vec.to(dtype=dtype) + # if dtype == torch.float64: + # return self.encode_f64_vec(vec) + # if dtype == torch.float32: + # return self.encode_f32_vec(vec) + if dtype == torch.int64: + return self.encode_i64_vec(vec) + if dtype == torch.int32: + return self.encode_i32_vec(vec) + raise NotImplementedError(f"{vec.dtype} not supported") + + def decode_vec(self, vec: FV, dtype: torch.dtype) -> V: + # if dtype == torch.float64: + # return self.decode_f64_vec(vec) + # if dtype == torch.float32: + # return self.decode_f32_vec(vec) + if dtype == torch.int64: + return self.decode_i64_vec(vec) + if dtype == torch.int32: + return self.decode_i32_vec(vec) + raise NotImplementedError(f"{dtype} not supported") + + def encode(self, val, dtype=None) -> FV: + if isinstance(val, torch.Tensor): + assert val.ndim == 0, "only scalar supported" + dtype = val.dtype + val = val.item() + # if dtype == torch.float64: + # return self.encode_f64(val) + # if dtype == torch.float32: + # return self.encode_f32(val) + if dtype == torch.int64: + return self.encode_i64(val) + if dtype == torch.int32: + return self.encode_i32(val) + raise NotImplementedError(f"{dtype} not supported") + + # def encode_f64(self, val: float): + # return self.coder.encode_f64(val) + # + # def decode_f64(self, val): + # return self.coder.decode_f64(val) + + def encode_i64(self, val: int): + return self.coder.encode_u64(val) + + def decode_i64(self, val): + return self.coder.decode_u64(val) + + # def encode_f32(self, val: float): + # return self.coder.encode_f32(val) + # + # def decode_f32(self, val): + # return self.coder.decode_f32(val) + + def encode_i32(self, val: int): + return self.coder.encode_u32(val) + + def decode_i32(self, val): + return self.coder.decode_u32(val) + + # def encode_f64_vec(self, vec: torch.Tensor): + # vec = vec.detach().flatten() + # return self.coder.encode_f64_vec(vec.detach().numpy()) + # + # def decode_f64_vec(self, vec): + # return torch.tensor(self.coder.decode_f64_vec(vec)) + + def encode_i64_vec(self, vec: torch.Tensor): + vec = vec.detach().flatten() + return self.coder.encode_u64_vec(vec.detach().numpy()) + + def decode_i64_vec(self, vec): + return torch.tensor(self.coder.decode_u64_vec(vec)) + + # def encode_f32_vec(self, vec: torch.Tensor): + # vec = vec.detach().flatten() + # return self.coder.encode_f32_vec(vec.detach().numpy()) + # + # def decode_f32_vec(self, vec): + # return torch.tensor(self.coder.decode_f32_vec(vec)) + + def encode_i32_vec(self, vec: torch.Tensor): + vec = vec.detach().flatten() + return self.coder.encode_u32_vec(vec.detach().numpy()) + + def decode_i32_vec(self, vec): + return torch.tensor(self.coder.decode_u32_vec(vec)) + + +def keygen(key_size): + sk, pk, coder = _keygen(key_size) + return SK(sk), PK(pk), Coder(coder) + + +class evaluator(TensorEvaluator[EV, V, PK, Coder]): + @staticmethod + def add(a: EV, b: EV, pk: PK): + return a.add(pk.pk, b) + + @staticmethod + def add_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.add(pk.pk, encrypted) + + @staticmethod + def add_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.add_scalar(pk.pk, encrypted) + + @staticmethod + def sub(a: EV, b: EV, pk: PK): + return a.sub(pk.pk, b) + + @staticmethod + def sub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.sub(pk.pk, encrypted) + + @staticmethod + def sub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.sub_scalar(pk.pk, encrypted) + + @staticmethod + def rsub(a: EV, b: EV, pk: PK): + return a.rsub(pk.pk, b) + + @staticmethod + def rsub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.rsub(pk.pk, encrypted) + + @staticmethod + def rsub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.rsub_scalar(pk.pk, encrypted) + + @staticmethod + def mul_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + return a.mul(pk.pk, encoded) + + @staticmethod + def mul_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + return a.mul_scalar(pk.pk, encoded) + + @staticmethod + def matmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode_tensor(b, dtype=output_dtype) + # TODO: move this to python side so other protocols can use it without matmul support? + return a.matmul(pk.pk, encoded, a_shape, b_shape) + + @staticmethod + def rmatmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode_tensor(b, dtype=output_dtype) + return a.rmatmul(pk.pk, encoded, a_shape, b_shape) + + @staticmethod + def zeros(size) -> EV: + return CiphertextVector.zeros(size) + + @staticmethod + def i_add(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: + """ + inplace add, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) + Args: + pk: the public key + a: the vector to add to + b: the vector to add + sa: the start index of a + sb: the start index of b + size: the size to add + """ + if a is b: + a.iadd_vec_self(sa, sb, size, pk.pk) + else: + a.iadd_vec(b, sa, sb, size, pk.pk) + + @staticmethod + def i_sub(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: + """ + inplace sub, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) + Args: + pk: the public key + a: the vector to add to + b: the vector to add + sa: the start index of a + sb: the start index of b + size: the size to add + """ + if a is b: + a.isub_vec_self(sa, sb, size, pk.pk) + else: + a.isub_vec(b, sa, sb, size, pk.pk) + + @staticmethod + def slice(a: EV, start: int, size: int) -> EV: + """ + slice a[start:start+size] + Args: + a: the vector to slice + start: the start index + size: the size to slice + + Returns: + the sliced vector + """ + return a.slice(start, size) + + @staticmethod + def i_shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> None: + """ + inplace shuffle, a = a[indices] + Args: + pk: public key, not used + a: the vector to shuffle + indices: the indices to shuffle + """ + a.i_shuffle(indices) + + @staticmethod + def shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> EV: + """ + shuffle, out = a[indices] + Args: + pk: public key, not used + a: the vector to shuffle + indices: the indices to shuffle + """ + return a.shuffle(indices) + + @staticmethod + def i_update(pk: PK, a: EV, b: EV, positions, stride: int) -> None: + """ + inplace update, a[positions] += b[::stride] + Args: + pk: public key, not used + a: the vector to update + b: the vector to update with + positions: the positions to update + stride: the stride to update + """ + a.iupdate(b, positions, stride, pk.pk) + + @staticmethod + def i_update_with_masks(pk: PK, a: EV, b: EV, positions, masks, stride: int) -> None: + """ + inplace update, a[positions] += b[::stride] + Args: + pk: public key, not used + a: the vector to update + b: the vector to update with + positions: the positions to update + stride: the stride to update + """ + a.iupdate_with_masks(b, positions, masks, stride, pk.pk) + + @staticmethod + def intervals_slice(a: EV, intervals: List[Tuple[int, int]]) -> EV: + """ + slice in the given intervals + + for example: + intervals=[(0, 4), (6, 12)], a = [a0, a1, a2, a3, a4, a5, a6, a7,...] + then the result is [a0, a1, a2, a3, a6, a7, a8, a9, a10, a11] + """ + return a.intervals_slice(intervals) + + @staticmethod + def cat(list: List[EV]) -> EV: + """ + concatenate the list of vectors + Args: + list: the list of vectors + + Returns: the concatenated vector + """ + return _Evaluator.cat(list) + + @staticmethod + def chunking_cumsum_with_step(pk: PK, a: EV, chunk_sizes: List[int], step: int): + """ + chunking cumsum with step size + + for example: + if step=2, chunk_sizes=[4, 2, 6], a = [a0, a1, a2, a3, a4, a5, a6, a7,...a11] + then the result is [a0, a1, a0+a2, a1+a3, a4, a5, a6, a7, a6+a8, a7+a9, a6+a8+a10, a7+a9+a11] + Args: + pk: the public key + a: the vector to cumsum + chunk_sizes: the chunk sizes, must sum to a.size + step: the step size, cumsum with skip step-1 elements + Returns: + the cumsum result + """ + return a.chunking_cumsum_with_step(pk.pk, chunk_sizes, step) + + @staticmethod + def pack_squeeze(a: EV, pack_num: int, shift_bit: int, pk: PK) -> EV: + return a.pack_squeeze(pack_num, shift_bit, pk.pk) + + +def test_pack_float(): + offset_bit = 32 + precision = 16 + coder = Coder(_Coder()) + vec = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + packed = coder.pack_floats(vec, offset_bit, 2, precision) + unpacked = coder.unpack_floats(packed, offset_bit, 2, precision, 5) + assert torch.allclose(vec, unpacked, rtol=1e-3, atol=1e-3) + + +def test_pack_squeeze(): + offset_bit = 32 + precision = 16 + pack_num = 2 + pack_packed_num = 2 + vec1 = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + vec2 = torch.tensor([0.6, 0.7, 0.8, 0.9, 1.0]) + sk, pk, coder = keygen(1024) + a = coder.pack_floats(vec1, offset_bit, pack_num, precision) + ea = pk.encrypt_encoded(a, obfuscate=False) + b = coder.pack_floats(vec2, offset_bit, pack_num, precision) + eb = pk.encrypt_encoded(b, obfuscate=False) + ec = evaluator.add(ea, eb, pk) + + # pack packed encrypted + ec_pack = evaluator.pack_squeeze(ec, pack_packed_num, offset_bit * 2, pk) + c_pack = sk.decrypt_to_encoded(ec_pack) + c = coder.unpack_floats(c_pack, offset_bit, pack_num * pack_packed_num, precision, 5) + assert torch.allclose(vec1 + vec2, c, rtol=1e-3, atol=1e-3) diff --git a/rust/fate_utils/crates/fate_utils/Cargo.toml b/rust/fate_utils/crates/fate_utils/Cargo.toml index abfd665160..8ed1e0f82e 100644 --- a/rust/fate_utils/crates/fate_utils/Cargo.toml +++ b/rust/fate_utils/crates/fate_utils/Cargo.toml @@ -29,6 +29,8 @@ math = { path = "../math" } fixedpoint = { path = "../fixedpoint" } paillier = { path = "../paillier" } fixedpoint_paillier = { path = "../fixedpoint_paillier" } +fixedpoint_ou = { path = "../fixedpoint_ou" } + [features] default = ["rug", "rayon", "std", "u64_backend", "extension-module"] diff --git a/rust/fate_utils/crates/fate_utils/src/lib.rs b/rust/fate_utils/crates/fate_utils/src/lib.rs index 5590baabc6..9b8a4056da 100644 --- a/rust/fate_utils/crates/fate_utils/src/lib.rs +++ b/rust/fate_utils/crates/fate_utils/src/lib.rs @@ -7,6 +7,8 @@ mod quantile; mod secure_aggregation_helper; mod paillier; +mod ou; + use pyo3::prelude::*; #[pymodule] @@ -16,6 +18,7 @@ fn fate_utils(py: Python, m: &PyModule) -> PyResult<()> { psi::register(py, m)?; histogram::register(py, m)?; paillier::register(py, m)?; + ou::register(py, m)?; secure_aggregation_helper::register(py, m)?; Ok(()) } diff --git a/rust/fate_utils/crates/fate_utils/src/ou/mod.rs b/rust/fate_utils/crates/fate_utils/src/ou/mod.rs new file mode 100644 index 0000000000..e3e21b16a6 --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/ou/mod.rs @@ -0,0 +1,12 @@ +mod ou; +use pyo3::prelude::*; + +pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule = PyModule::new(py, "ou")?; + ou::register(py, submodule)?; + m.add_submodule(submodule)?; + py.import("sys")? + .getattr("modules")? + .set_item("fate_utils.ou", submodule)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/ou/ou.rs b/rust/fate_utils/crates/fate_utils/src/ou/ou.rs new file mode 100644 index 0000000000..237536e8d0 --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/ou/ou.rs @@ -0,0 +1,443 @@ +use numpy::PyReadonlyArray1; +use pyo3::exceptions::PyRuntimeError; +use pyo3::prelude::*; +use anyhow::Error as AnyhowError; + +trait ToPyErr { + fn to_py_err(self) -> PyErr; +} + +impl ToPyErr for AnyhowError { + fn to_py_err(self) -> PyErr { + PyRuntimeError::new_err(self.to_string()) + } +} + + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct PK(fixedpoint_ou::PK); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct SK(fixedpoint_ou::SK); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct Coder(fixedpoint_ou::Coder); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct Ciphertext(fixedpoint_ou::Ciphertext); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default, Debug)] +pub struct CiphertextVector(fixedpoint_ou::CiphertextVector); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default, Debug)] +pub struct PlaintextVector(fixedpoint_ou::PlaintextVector); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct Plaintext(fixedpoint_ou::Plaintext); + +#[pyclass] +pub struct Evaluator {} + +#[pymethods] +impl PK { + fn encrypt_encoded( + &self, + plaintext_vector: &PlaintextVector, + obfuscate: bool, + ) -> CiphertextVector { + CiphertextVector(self.0.encrypt_encoded(&plaintext_vector.0, obfuscate)) + } + fn encrypt_encoded_scalar(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + Ciphertext(self.0.encrypt_encoded_scalar(&plaintext.0, obfuscate)) + } + + #[new] + fn __new__() -> PyResult { + Ok(PK::default()) + } + + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } +} + +#[pymethods] +impl SK { + fn decrypt_to_encoded(&self, data: &CiphertextVector) -> PlaintextVector { + PlaintextVector(self.0.decrypt_to_encoded(&data.0)) + } + fn decrypt_to_encoded_scalar(&self, data: &Ciphertext) -> Plaintext { + Plaintext(self.0.decrypt_to_encoded_scalar(&data.0)) + } + + #[new] + fn __new__() -> PyResult { + Ok(SK::default()) + } + + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } +} + +#[pymethods] +impl Coder { + // fn encode_f64(&self, data: f64) -> Plaintext { + // Plaintext(self.0.encode_f64(data)) + // } + // fn decode_f64(&self, data: &Plaintext) -> f64 { + // self.0.decode_f64(&data.0) + // } + // fn encode_f32(&self, data: f32) -> Plaintext { + // Plaintext(self.0.encode_f32(data)) + // } + fn encode_u64(&self, data: u64) -> Plaintext { + Plaintext(self.0.encode_u64(data)) + } + fn decode_u64(&self, data: &Plaintext) -> u64 { + self.0.decode_u64(&data.0) + } + fn encode_u32(&self, data: u32) -> Plaintext { + Plaintext(self.0.encode_u32(data)) + } + fn decode_u32(&self, data: &Plaintext) -> u32 { + self.0.decode_u32(&data.0) + } + #[new] + fn __new__() -> PyResult { + Ok(Coder::default()) + } + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } + + fn pack_floats(&self, float_tensor: Vec, offset_bit: usize, pack_num: usize, precision: u32) -> PlaintextVector { + let data = self.0.pack_floats(&float_tensor, offset_bit, pack_num, precision); + PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + } + + fn unpack_floats(&self, packed: &PlaintextVector, offset_bit: usize, pack_num: usize, precision: u32, total_num: usize) -> Vec { + self.0.unpack_floats(&packed.0.data, offset_bit, pack_num, precision, total_num) + } + // fn encode_f64_vec(&self, data: PyReadonlyArray1) -> PlaintextVector { + // let data = data + // .as_array() + // .iter() + // .map(|x| self.0.encode_f64(*x)) + // .collect(); + // PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + // } + // fn decode_f64_vec<'py>(&self, data: &PlaintextVector, py: Python<'py>) -> &'py PyArray1 { + // Array1::from( + // data.0.data + // .iter() + // .map(|x| self.0.decode_f64(x)) + // .collect::>(), + // ) + // .into_pyarray(py) + // } + // fn encode_f32_vec(&self, data: PyReadonlyArray1) -> PlaintextVector { + // let data = data + // .as_array() + // .iter() + // .map(|x| self.0.encode_f32(*x)) + // .collect(); + // PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + // } + // fn decode_f32(&self, data: &Plaintext) -> f32 { + // self.0.decode_f32(&data.0) + // } + // fn decode_f32_vec<'py>(&self, data: &PlaintextVector, py: Python<'py>) -> &'py PyArray1 { + // Array1::from( + // data.0.data + // .iter() + // .map(|x| self.0.decode_f32(x)) + // .collect::>(), + // ) + // .into_pyarray(py) + // } + fn encode_u64_vec(&self, data: PyReadonlyArray1) -> PlaintextVector { + let data = data + .as_array() + .iter() + .map(|x| self.0.encode_u64(*x)) + .collect(); + PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + } + fn decode_u64_vec(&self, data: &PlaintextVector) -> Vec { + data.0.data.iter().map(|x| self.0.decode_u64(x)).collect() + } + fn encode_u32_vec(&self, data: PyReadonlyArray1) -> PlaintextVector { + let data = data + .as_array() + .iter() + .map(|x| self.0.encode_u32(*x)) + .collect(); + PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + } + fn decode_u32_vec(&self, data: &PlaintextVector) -> Vec { + data.0.data.iter().map(|x| self.0.decode_u32(x)).collect() + } +} + +#[pyfunction] +fn keygen(bit_length: u32) -> (SK, PK, Coder) { + let (sk, pk, coder) = fixedpoint_ou::keygen(bit_length); + (SK(sk), PK(pk), Coder(coder)) +} + +#[pymethods] +impl CiphertextVector { + #[new] + fn __new__() -> PyResult { + Ok(CiphertextVector(fixedpoint_ou::CiphertextVector { data: vec![] })) + } + + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } + + fn __len__(&self) -> usize { + self.0.data.len() + } + + fn __str__(&self) -> String { + format!("{:?}", self.0) + } + + #[staticmethod] + pub fn zeros(size: usize) -> PyResult { + Ok(CiphertextVector(fixedpoint_ou::CiphertextVector::zeros(size))) + } + + pub fn pack_squeeze(&self, pack_num: usize, offset_bit: u32, pk: &PK) -> PyResult { + Ok(CiphertextVector(self.0.pack_squeeze(&pk.0, pack_num, offset_bit))) + } + + fn slice(&mut self, start: usize, size: usize) -> CiphertextVector { + CiphertextVector(self.0.slice(start, size)) + } + + fn slice_indexes(&mut self, indexes: Vec) -> PyResult { + Ok(CiphertextVector(self.0.slice_indexes(indexes))) + } + pub fn cat(&self, others: Vec>) -> PyResult { + Ok(CiphertextVector(self.0.cat(others.iter().map(|x| &x.0).collect()))) + } + fn i_shuffle(&mut self, indexes: Vec) { + self.0.i_shuffle(indexes); + } + + fn shuffle(&self, indexes: Vec) -> PyResult { + Ok(CiphertextVector(self.0.shuffle(indexes))) + } + fn intervals_slice(&mut self, intervals: Vec<(usize, usize)>) -> PyResult { + Ok(CiphertextVector(self.0.intervals_slice(intervals).map_err(|e| e.to_py_err())?)) + } + fn iadd_slice(&mut self, pk: &PK, position: usize, other: Vec>) { + self.0.iadd_slice(&pk.0, position, other.iter().map(|x| &x.0).collect()); + } + fn iadd_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option, + pk: &PK, + ) -> PyResult<()> { + self.0.iadd_vec_self(sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn isub_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option, + pk: &PK, + ) -> PyResult<()> { + self.0.isub_vec_self(sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn iadd_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option, + pk: &PK, + ) -> PyResult<()> { + self.0.iadd_vec(&other.0, sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn isub_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option, + pk: &PK, + ) -> PyResult<()> { + self.0.isub_vec(&other.0, sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn iupdate(&mut self, other: &CiphertextVector, indexes: Vec>, stride: usize, pk: &PK) -> PyResult<()> { + self.0.iupdate(&other.0, indexes, stride, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn iupdate_with_masks(&mut self, other: &CiphertextVector, indexes: Vec>, masks: Vec, stride: usize, pk: &PK) -> PyResult<()> { + self.0.iupdate_with_masks(&other.0, indexes, masks, stride, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn iadd(&mut self, pk: &PK, other: &CiphertextVector) { + self.0.iadd(&pk.0, &other.0); + } + fn idouble(&mut self, pk: &PK) { + self.0.idouble(&pk.0); + } + fn chunking_cumsum_with_step(&mut self, pk: &PK, chunk_sizes: Vec, step: usize) { + self.0.chunking_cumsum_with_step(&pk.0, chunk_sizes, step); + } + fn intervals_sum_with_step( + &mut self, + pk: &PK, + intervals: Vec<(usize, usize)>, + step: usize, + ) -> CiphertextVector { + CiphertextVector(self.0.intervals_sum_with_step(&pk.0, intervals, step)) + } + + fn tolist(&self) -> Vec { + self.0.tolist().iter().map(|x| CiphertextVector(x.clone())).collect() + } + + fn add(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.add(&pk.0, &other.0)) + } + fn add_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.add_scalar(&pk.0, &other.0)) + } + fn sub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.sub(&pk.0, &other.0)) + } + fn sub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.sub_scalar(&pk.0, &other.0)) + } + fn rsub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.rsub(&pk.0, &other.0)) + } + fn rsub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.rsub_scalar(&pk.0, &other.0)) + } + fn mul(&self, pk: &PK, other: &PlaintextVector) -> CiphertextVector { + CiphertextVector(self.0.mul(&pk.0, &other.0)) + } + fn mul_scalar(&self, pk: &PK, other: &Plaintext) -> CiphertextVector { + CiphertextVector(self.0.mul_scalar(&pk.0, &other.0)) + } + + fn matmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec, + rshape: Vec, + ) -> CiphertextVector { + CiphertextVector(self.0.matmul(&pk.0, &other.0, lshape, rshape)) + } + + fn rmatmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec, + rshape: Vec, + ) -> CiphertextVector { + CiphertextVector(self.0.rmatmul(&pk.0, &other.0, lshape, rshape)) + } +} + +#[pymethods] +impl PlaintextVector { + #[new] + fn __new__() -> PyResult { + Ok(PlaintextVector(fixedpoint_ou::PlaintextVector { data: vec![] })) + } + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } + fn __str__(&self) -> String { + format!("{:?}", self.0) + } + fn get_stride(&mut self, index: usize, stride: usize) -> PlaintextVector { + PlaintextVector(self.0.get_stride(index, stride)) + } + fn tolist(&self) -> Vec { + self.0.tolist().iter().map(|x| Plaintext(x.clone())).collect() + } +} + +#[pymethods] +impl Evaluator { + #[staticmethod] + fn cat(vec_list: Vec<PyRef<CiphertextVector>>) -> PyResult<CiphertextVector> { + let mut data = vec![fixedpoint_ou::Ciphertext::zero(); 0]; + for vec in vec_list { + data.extend(vec.0.data.clone()); + } + Ok(CiphertextVector(fixedpoint_ou::CiphertextVector { data })) + } + #[staticmethod] + fn slice_indexes(a: &CiphertextVector, indexes: Vec<usize>) -> PyResult<CiphertextVector> { + let data = indexes + .iter() + .map(|i| a.0.data[*i].clone()) + .collect::<Vec<_>>(); + Ok(CiphertextVector(fixedpoint_ou::CiphertextVector { data })) + } +} + +pub(crate) fn register(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::<CiphertextVector>()?; + m.add_class::<PlaintextVector>()?; + m.add_class::<PK>()?; + m.add_class::<SK>()?; + m.add_class::<Coder>()?; + m.add_class::<Ciphertext>()?; + m.add_class::<Evaluator>()?; + m.add_function(wrap_pyfunction!(keygen, m)?)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fixedpoint_ou/Cargo.toml b/rust/fate_utils/crates/fixedpoint_ou/Cargo.toml new file mode 100644 index 0000000000..da5809fb78 --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_ou/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "fixedpoint_ou" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serde = { workspace = true} +rug = { workspace = true } +anyhow = { workspace = true } +math = { path = "../math" } +ou = { path = "../ou" } + +[dev-dependencies] +rand = { workspace = true } diff --git a/rust/fate_utils/crates/fixedpoint_ou/src/frexp.rs b/rust/fate_utils/crates/fixedpoint_ou/src/frexp.rs new file mode 100644 index 0000000000..019f28b900 --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_ou/src/frexp.rs @@ -0,0 +1,26 @@ +use std::os::raw::{c_double, c_float, c_int}; + +extern "C" { + fn frexp(x: c_double, exp: *mut c_int) -> c_double; + fn frexpf(x: c_float, exp: *mut c_int) -> c_float; +} + +pub trait Frexp: Sized { + fn frexp(self) -> (Self, i32); +} + +impl Frexp for f64 { + fn frexp(self) -> (Self, i32) { + let mut exp: c_int = 0; + let res = unsafe { frexp(self, &mut exp) }; + (res, exp) + } +} + +impl Frexp for f32 { + fn frexp(self) -> (Self, i32) { + let mut exp: c_int = 0; + let res = unsafe { frexpf(self, &mut exp) }; + (res, exp) + } +} diff --git a/rust/fate_utils/crates/fixedpoint_ou/src/lib.rs b/rust/fate_utils/crates/fixedpoint_ou/src/lib.rs new file mode 100644 index 0000000000..f57ffe43e1 --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_ou/src/lib.rs @@ -0,0 +1,865 @@ +use math::BInt; +use ou; +use anyhow::Result; +use anyhow::anyhow; +use std::ops::{AddAssign, BitAnd, Mul, ShlAssign, SubAssign}; +use rug::{self, Integer, ops::Pow, Float, Rational}; +use serde::{Deserialize, Serialize}; + +mod frexp; + +// use frexp::Frexp; + +const BASE: u32 = 16; +// const MAX_INT_FRACTION: u8 = 2; +// const FLOAT_MANTISSA_BITS: u32 = 53; +const LOG2_BASE: u32 = 4; + +#[derive(Default, Serialize, Deserialize)] +pub struct PK { + pub pk: ou::PK, + // pub max_int: BInt, +} + +impl PK { + #[inline] + pub fn encrypt(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + let exp = plaintext.exp; + let encode = self.pk.encrypt(&plaintext.significant, obfuscate); + Ciphertext { + significant_encryped: encode, + exp, + } + } +} + +#[derive(Default, Serialize, Deserialize)] +pub struct SK { + pub sk: ou::SK, +} + +impl SK { + #[inline] + pub fn decrypt(&self, ciphertext: &Ciphertext) -> Plaintext { + let exp = ciphertext.exp; + Plaintext { + significant: self.sk.decrypt(&ciphertext.significant_encryped), + exp, + } + } +} + + +/// fixedpoint encoder +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Coder {} + +impl Coder { + pub fn new() -> Self { + Coder {} + } + + pub fn encode_u64(&self, plaintext: u64) -> Plaintext { + let significant = ou::PT(BInt::from(plaintext)); + Plaintext { + significant, + exp: 0, + } + } + pub fn pack_floats(&self, floats: &Vec<f64>, offset_bit: usize, pack_num: usize, precision: u32) -> Vec<Plaintext> { + let int_scale = Integer::from(2).pow(precision); + floats.chunks(pack_num).map(|data| { + let significant = data.iter().fold(Integer::default(), |mut x, v| { + x.shl_assign(offset_bit); + x.add_assign(Float::with_val(64, v).mul(&int_scale).round().to_integer().unwrap()); + x + }); + Plaintext { + significant: ou::PT(BInt(significant)), + exp: 0, + } + }) + .collect() + } + pub fn unpack_floats(&self, encoded: &[Plaintext], offset_bit: usize, pack_num: usize, precision: u32, expect_total_num: usize) -> Vec<f64> { + let int_scale = Integer::from(2).pow(precision); + let mut mask = Integer::from(1); + mask <<= offset_bit; + mask.sub_assign(1); + let mut result = Vec::with_capacity(expect_total_num); + let mut total_num = expect_total_num; + for x in encoded { + let n = std::cmp::min(total_num, pack_num); + let mut significant = x.significant.0.0.clone(); + let mut temp = Vec::with_capacity(n); + for _ in 0..n { + let value = Rational::from(((&significant).bitand(&mask), &int_scale)).to_f64(); + temp.push(value); + significant >>= offset_bit; + } + temp.reverse(); + result.extend(temp); + total_num -= n; + } + #[cfg(debug_assertions)] + assert_eq!(result.len(), expect_total_num); + + result + } + pub fn encode_u32(&self, plaintext: u32) -> Plaintext { + let significant = ou::PT( + BInt::from(plaintext) + ); + Plaintext { + significant, + exp: 0, + } + } + pub fn decode_u64(&self, encoded: &Plaintext) -> u64 { + let significant = encoded.significant.0.clone(); + let mantissa = significant; + (mantissa << (LOG2_BASE as i32 * encoded.exp)).to_i128() as u64 + } + pub fn decode_u32(&self, encoded: &Plaintext) -> u32 { + // Todo: could be improved + self.decode_u64(encoded) as u32 + } + + // pub fn encode_f64(&self, plaintext: f64) -> Plaintext { + // let bin_flt_exponent = plaintext.frexp().1; + // let bin_lsb_exponent = bin_flt_exponent - (FLOAT_MANTISSA_BITS as i32); + // let exp = (bin_lsb_exponent as f64 / LOG2_BASE as f64).floor() as i32; + // let significant = BInt( + // (plaintext * rug::Float::with_val(FLOAT_MANTISSA_BITS, BASE).pow(-exp)) + // .round() + // .to_integer() + // .unwrap(), + // ); + // if significant.abs_ref() > self.max_int { + // panic!( + // "Integer needs to be within +/- {} but got {}", + // self.max_int.0, &significant.0 + // ) + // } + // Plaintext { + // significant: ou::PT(significant), + // exp, + // } + // } + // pub fn decode_f64(&self, encoded: &Plaintext) -> f64 { + // let significant = encoded.significant.0.clone(); + // let mantissa = if significant > self.n { + // panic!("Attempted to decode corrupted number") + // } else if significant <= self.max_int { + // significant + // } else if significant >= BInt::from(&self.n - &self.max_int) { + // significant - &self.n + // } else { + // format!("Overflow detected in decrypted number: {:?}", significant); + // panic!("Overflow detected in decrypted number") + // }; + // if encoded.exp >= 0 { + // (mantissa << (LOG2_BASE as i32 * encoded.exp)).to_f64() + // } else { + // (mantissa * rug::Float::with_val(FLOAT_MANTISSA_BITS, BASE).pow(encoded.exp)).to_f64() + // } + // } + // pub fn encode_f32(&self, plaintext: f32) -> Plaintext { + // self.encode_f64(plaintext as f64) + // } + // pub fn decode_f32(&self, encoded: &Plaintext) -> f32 { + // self.decode_f64(encoded) as f32 + // } +} + + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct Ciphertext { + pub significant_encryped: ou::CT, + pub exp: i32, +} + +impl Ciphertext { + pub fn zero() -> Ciphertext { + Ciphertext { + significant_encryped: ou::CT::zero(), + exp: 0, + } + } + fn decrese_exp_to(&self, exp: i32, pk: &ou::PK) -> Ciphertext { + assert!(exp < self.exp); + let factor = BInt::from(BASE).pow((self.exp - exp) as u32); + let significant_encryped = self.significant_encryped.mul_pt(&ou::PT(factor), pk); + Ciphertext { + significant_encryped, + exp, + } + } + pub fn neg(&self, pk: &PK) -> Ciphertext { + Ciphertext { + significant_encryped: ou::CT(self.significant_encryped.0.invert_ref(&pk.pk.n)), + exp: self.exp, + } + } + pub fn add_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + self.add(&b, pk) + } + pub fn sub_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + self.sub(&b, pk) + } + /* + other - self + */ + pub fn rsub_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + b.sub(self, pk) + } + pub fn sub(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + self.add(&b.neg(pk), pk) + } + pub fn rsub(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + self.neg(pk).add(&b, pk) + } + pub fn add_assign(&mut self, b: &Ciphertext, pk: &PK) { + // FIXME + *self = self.add(b, pk); + } + pub fn sub_assign(&mut self, b: &Ciphertext, pk: &PK) { + // FIXME + *self = self.sub(b, pk); + } + pub fn i_double(&mut self, pk: &PK) { + self.significant_encryped.0 = self + .significant_encryped + .0 + .pow_mod_ref(&BInt::from(2), &pk.pk.n); + } + + pub fn add(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + let a = self; + if a.significant_encryped.0.0 == 1 { + return b.clone(); + } + if b.significant_encryped.0.0 == 1 { + return a.clone(); + } + if a.exp > b.exp { + let a = &a.decrese_exp_to(b.exp, &pk.pk); + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: b.exp, + } + } else if a.exp < b.exp { + let b = &b.decrese_exp_to(a.exp, &pk.pk); + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: a.exp, + } + } else { + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: a.exp, + } + } + } + pub fn mul(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let inside = (&self.significant_encryped.0).pow_mod_ref(&b.significant.0, &pk.pk.n); + Ciphertext { + significant_encryped: ou::CT(inside), + exp: self.exp + b.exp, + } + } +} + + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct CiphertextVector { + pub data: Vec<Ciphertext>, +} + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct Plaintext { + pub significant: ou::PT, + pub exp: i32, +} + +#[derive(Default, Debug, Serialize, Deserialize)] +pub struct PlaintextVector { + pub data: Vec<Plaintext>, +} + +impl PK { + pub fn encrypt_encoded( + &self, + plaintext: &PlaintextVector, + obfuscate: bool, + ) -> CiphertextVector { + let data = plaintext + .data + .iter() + .map(|x| Ciphertext { significant_encryped: self.pk.encrypt(&x.significant, obfuscate), exp: x.exp }) + .collect(); + CiphertextVector { data } + } + pub fn encrypt_encoded_scalar(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + Ciphertext { + significant_encryped: self.pk.encrypt(&plaintext.significant, obfuscate), + exp: plaintext.exp, + } + } +} + + +impl SK { + pub fn decrypt_to_encoded(&self, data: &CiphertextVector) -> PlaintextVector { + let data = data.data.iter().map(|x| Plaintext { + significant: + self.sk.decrypt(&x.significant_encryped), + exp: x.exp, + }).collect(); + PlaintextVector { data } + } + pub fn decrypt_to_encoded_scalar(&self, data: &Ciphertext) -> Plaintext { + Plaintext { + significant: self.sk.decrypt(&data.significant_encryped), + exp: data.exp, + } + } +} + +pub fn keygen(bit_length: u32) -> (SK, PK, Coder) { + let (sk, pk) = ou::keygen(bit_length); + let coder = Coder::new(); + // let max_int = &sk.p / MAX_INT_FRACTION; + (SK { sk }, PK { pk: pk }, coder) +} + +impl CiphertextVector { + #[inline] + fn iadd_i_j(&mut self, pk: &PK, i: usize, j: usize, size: usize) { + let mut placeholder = Ciphertext::default(); + for k in 0..size { + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + placeholder.add_assign(&self.data[j + k], &pk); + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + } + } + #[inline] + fn isub_i_j(&mut self, pk: &PK, i: usize, j: usize, size: usize) { + let mut placeholder = Ciphertext::default(); + for k in 0..size { + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + placeholder.sub_assign(&self.data[j + k], &pk); + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + } + } + pub fn zeros(size: usize) -> Self { + let data = vec![Ciphertext::zero(); size]; + CiphertextVector { data } + } + + pub fn pack_squeeze(&self, pk: &PK, pack_num: usize, shift_bit: u32) -> CiphertextVector { + let base = BInt::from(2).pow(shift_bit); + let data = self.data.chunks(pack_num).map(|x| { + let mut result = x[0].significant_encryped.0.clone(); + for y in &x[1..] { + result.pow_mod_mut(&base, &pk.pk.n); + result = result.mul(&y.significant_encryped.0) % &pk.pk.n; + } + Ciphertext { significant_encryped: ou::CT(result), exp: 0 } + }).collect(); + CiphertextVector { data } + } + + pub fn slice(&mut self, start: usize, size: usize) -> CiphertextVector { + let data = self.data[start..start + size].to_vec(); + CiphertextVector { data } + } + + pub fn slice_indexes(&mut self, indexes: Vec<usize>) -> Self { + let data = indexes + .iter() + .map(|i| self.data[*i].clone()) + .collect::<Vec<_>>(); + CiphertextVector { data } + } + + pub fn cat(&self, others: Vec<&CiphertextVector>) -> Self { + let mut data = self.data.clone(); + for other in others { + data.extend(other.data.clone()); + } + CiphertextVector { data } + } + + pub fn i_shuffle(&mut self, indexes: Vec<usize>) { + let mut visited = vec![false; self.data.len()]; + for i in 0..self.data.len() { + if visited[i] || indexes[i] == i { + continue; + } + + let mut current = i; + let mut next = indexes[current]; + while !visited[next] && next != i { + self.data.swap(current, next); + visited[current] = true; + current = next; + next = indexes[current]; + } + visited[current] = true; + } + } + + pub fn shuffle(&self, indexes: Vec<usize>) -> Self { + let data = self.data.clone(); + let mut result = CiphertextVector { data }; + result.i_shuffle(indexes); + result + } + + pub fn intervals_slice(&mut self, intervals: Vec<(usize, usize)>) -> Result<Self> { + let mut data = vec![]; + for (start, end) in intervals { + if end > self.data.len() { + return Err(anyhow!( + "end index out of range: start={}, end={}, data_size={}", + start, + end, + self.data.len() + )); + } + data.extend_from_slice(&self.data[start..end]); + } + Ok(CiphertextVector { data }) + } + + pub fn iadd_slice(&mut self, pk: &PK, position: usize, other: Vec<&Ciphertext>) { + for (i, x) in other.iter().enumerate() { + self.data[position + i] = self.data[position + i].add(&x, &pk); + } + } + + pub fn iadd_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + if sa == sb { + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.data[sa..sa + s] + .iter_mut() + .for_each(|x| x.i_double(&pk)); + } else { + self.data[sa..].iter_mut().for_each(|x| x.i_double(&pk)); + } + } else if sa < sb { + // it's safe to update from left to right + if let Some(s) = size { + if sb + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, s={}, data_size={}", + sb, + s, + self.data.len() + )); + } + self.iadd_i_j(&pk, sb, sa, s); + } else { + self.iadd_i_j(&pk, sb, sa, self.data.len() - sb); + } + } else { + // it's safe to update from right to left + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.iadd_i_j(&pk, sa, sb, s); + } else { + self.iadd_i_j(&pk, sa, sb, self.data.len() - sa); + } + } + Ok(()) + } + pub fn isub_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + if sa == sb { + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.data[sa..sa + s] + .iter_mut() + .for_each(|x| *x = Ciphertext::zero()); + } else { + self.data[sa..].iter_mut().for_each(|x| *x = Ciphertext::zero()); + } + } else if sa < sb { + // it's safe to update from left to right + if let Some(s) = size { + if sb + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, s={}, data_size={}", + sb, + s, + self.data.len() + )); + } + self.isub_i_j(&pk, sb, sa, s); + } else { + self.isub_i_j(&pk, sb, sa, self.data.len() - sb); + } + } else { + // it's safe to update from right to left + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.isub_i_j(&pk, sa, sb, s); + } else { + self.isub_i_j(&pk, sa, sb, self.data.len() - sa); + } + } + Ok(()) + } + + pub fn iadd_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + match size { + Some(s) => { + let ea = sa + s; + let eb = sb + s; + if ea > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, ea={}, data_size={}", + sa, + ea, + self.data.len() + )); + } + if eb > other.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, eb={}, data_size={}", + sb, + eb, + other.data.len() + )); + } + self.data[sa..ea] + .iter_mut() + .zip(other.data[sb..eb].iter()) + .for_each(|(x, y)| { + x.add_assign(y, &pk) + }); + } + None => { + self.data[sa..] + .iter_mut() + .zip(other.data[sb..].iter()) + .for_each(|(x, y)| x.add_assign(y, &pk)); + } + }; + Ok(()) + } + + pub fn isub_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + match size { + Some(s) => { + let ea = sa + s; + let eb = sb + s; + if ea > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, ea={}, data_size={}", + sa, + ea, + self.data.len() + )); + } + if eb > other.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, eb={}, data_size={}", + sb, + eb, + other.data.len() + )); + } + self.data[sa..ea] + .iter_mut() + .zip(other.data[sb..eb].iter()) + .for_each(|(x, y)| { + x.sub_assign(y, &pk) + }); + } + None => { + self.data[sa..] + .iter_mut() + .zip(other.data[sb..].iter()) + .for_each(|(x, y)| x.sub_assign(y, &pk)); + } + }; + Ok(()) + } + + pub fn iupdate(&mut self, other: &CiphertextVector, indexes: Vec<Vec<usize>>, stride: usize, pk: &PK) -> Result<()> { + for (i, x) in indexes.iter().enumerate() { + let sb = i * stride; + for pos in x.iter() { + let sa = pos * stride; + for i in 0..stride { + self.data[sa + i].add_assign(&other.data[sb + i], &pk); + } + } + } + Ok(()) + } + pub fn iupdate_with_masks(&mut self, other: &CiphertextVector, indexes: Vec<Vec<usize>>, masks: Vec<bool>, stride: usize, pk: &PK) -> Result<()> { + for (value_pos, x) in masks.iter().enumerate().filter(|(_, &mask)| mask).map(|(i, _)| i).zip(indexes.iter()) { + let sb = value_pos * stride; + for pos in x.iter() { + let sa = pos * stride; + for i in 0..stride { + self.data[sa + i].add_assign(&other.data[sb + i], &pk); + } + } + } + Ok(()) + } + + pub fn iadd(&mut self, pk: &PK, other: &CiphertextVector) { + self.data + .iter_mut() + .zip(other.data.iter()) + .for_each(|(x, y)| x.add_assign(y, &pk)); + } + + pub fn idouble(&mut self, pk: &PK) { + // TODO: fix me, remove clone + self.data + .iter_mut() + .for_each(|x| x.add_assign(&x.clone(), &pk)); + } + + pub fn chunking_cumsum_with_step(&mut self, pk: &PK, chunk_sizes: Vec<usize>, step: usize) { + let mut placeholder = Ciphertext::zero(); + let mut i = 0; + for chunk_size in chunk_sizes { + for j in step..chunk_size { + placeholder = std::mem::replace(&mut self.data[i + j], placeholder); + placeholder.add_assign(&self.data[i + j - step], &pk); + placeholder = std::mem::replace(&mut self.data[i + j], placeholder); + } + i += chunk_size; + } + } + + pub fn intervals_sum_with_step( + &mut self, + pk: &PK, + intervals: Vec<(usize, usize)>, + step: usize, + ) -> CiphertextVector { + let mut data = vec![Ciphertext::zero(); intervals.len() * step]; + for (i, (s, e)) in intervals.iter().enumerate() { + let chunk = &mut data[i * step..(i + 1) * step]; + let sub_vec = &self.data[*s..*e]; + for (val, c) in sub_vec.iter().zip((0..step).cycle()) { + chunk[c].add_assign(val, &pk); + } + } + CiphertextVector { data } + } + + pub fn tolist(&self) -> Vec<CiphertextVector> { + self.data.iter().map(|x| CiphertextVector { data: vec![x.clone()] }).collect() + } + + pub fn add(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.add(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn add_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| x.add(&other, &pk)).collect(); + CiphertextVector { data } + } + + pub fn sub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.sub(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn sub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| x.sub(&other, &pk)).collect(); + CiphertextVector { data } + } + + pub fn rsub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| y.sub(x, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn rsub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| other.sub(x, &pk)).collect(); + CiphertextVector { data } + } + + pub fn mul(&self, pk: &PK, other: &PlaintextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.mul(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn mul_scalar(&self, pk: &PK, other: &Plaintext) -> CiphertextVector { + let data = self + .data + .iter() + .map(|x| x.mul(&other, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn matmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec<usize>, + rshape: Vec<usize>, + ) -> CiphertextVector { + let mut data = vec![Ciphertext::zero(); lshape[0] * rshape[1]]; + for i in 0..lshape[0] { + for j in 0..rshape[1] { + for k in 0..lshape[1] { + data[i * rshape[1] + j].add_assign( + &self.data[i * lshape[1] + k].mul(&other.data[k * rshape[1] + j], &pk), + &pk, + ); + } + } + } + CiphertextVector { data } + } + + pub fn rmatmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec<usize>, + rshape: Vec<usize>, + ) -> CiphertextVector { + // rshape, lshape -> rshape[0] x lshape[1] + // other, self + // 4 x 2, 2 x 5 + // ik, kj -> ij + let mut data = vec![Ciphertext::zero(); lshape[1] * rshape[0]]; + for i in 0..rshape[0] { + // 4 + for j in 0..lshape[1] { + // 5 + for k in 0..rshape[1] { + // 2 + data[i * lshape[1] + j].add_assign( + &self.data[k * lshape[1] + j].mul(&other.data[i * rshape[1] + k], &pk), + &pk, + ); + } + } + } + CiphertextVector { data } + } +} + +impl PlaintextVector { + pub fn get_stride(&mut self, index: usize, stride: usize) -> PlaintextVector { + let start = index * stride; + let end = start + stride; + let data = self.data[start..end].to_vec(); + PlaintextVector { data } + } + pub fn tolist(&self) -> Vec<Plaintext> { + self.data + .iter() + .map(|x| x.clone()) + .collect() + } +} + +#[test] +fn test_decrypt() { + let (sk, pk, coder) = keygen(1024); + let mut data = vec![0.5, -0.5]; + let encoded = PlaintextVector { data: data.iter().map(|x| coder.encode_f64(*x)).collect() }; + let encrypted = pk.encrypt_encoded(&encoded, false); + let decrypted = sk.decrypt_to_encoded(&encrypted); + let decoded = decrypted.data.iter().map(|x| coder.decode_f64(x)).collect::<Vec<_>>(); + assert_eq!(data, decoded); +} \ No newline at end of file diff --git a/rust/fate_utils/crates/ou/Cargo.toml b/rust/fate_utils/crates/ou/Cargo.toml new file mode 100644 index 0000000000..121b7ab678 --- /dev/null +++ b/rust/fate_utils/crates/ou/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "ou" +version = "0.1.0" +edition = "2021" + +[dependencies] +math = { path = "../math" } +serde = { workspace = true } + +[dev-dependencies] +criterion = { workspace = true } +iai = { workspace = true} + +[[bench]] +name = "ou_bench" +harness = false \ No newline at end of file diff --git a/rust/fate_utils/crates/ou/benches/ou_bench.rs b/rust/fate_utils/crates/ou/benches/ou_bench.rs new file mode 100644 index 0000000000..fcccd733e4 --- /dev/null +++ b/rust/fate_utils/crates/ou/benches/ou_bench.rs @@ -0,0 +1,36 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use math::BInt; +use std::time::Duration; + +fn paillier_benchmark(c: &mut Criterion) { + let (sk, pk) = ou::keygen(1024); + let plaintext = ou::PT(BInt::from_str_radix("1234567890987654321", 10)); + let ciphertext = pk.encrypt(&plaintext, true); + let mut group = c.benchmark_group("paillier"); + + group.bench_function("keygen-1024", |b| { + b.iter(|| ou::keygen(black_box(1024))) + }); + group.bench_function("keygen-2048", |b| { + b.iter(||ou::keygen(black_box(1024))) + }); + group.bench_function("encrypt", |b| { + b.iter(|| black_box(&pk).encrypt(black_box(&plaintext), true)) + }); + group.bench_function("decrypt", |b| { + b.iter(|| black_box(&sk).decrypt(black_box(&ciphertext))) + }); + group.bench_function("add ciphertext", |b| { + b.iter(|| black_box(&ciphertext).add_ct(black_box(&ciphertext), black_box(&pk))) + }); + group.bench_function("mul plaintext", |b| { + b.iter(|| black_box(&ciphertext).mul_pt(black_box(&plaintext), black_box(&pk))) + }); +} + +criterion_group! { + name = benches; + config = Criterion::default().measurement_time(Duration::from_secs(10)); + targets = paillier_benchmark +} +criterion_main!(benches); diff --git a/rust/fate_utils/crates/ou/src/lib.rs b/rust/fate_utils/crates/ou/src/lib.rs new file mode 100644 index 0000000000..976764af8a --- /dev/null +++ b/rust/fate_utils/crates/ou/src/lib.rs @@ -0,0 +1,181 @@ +use math::{BInt, ONE}; +use serde::{Deserialize, Serialize}; +use std::fmt::{Display, Formatter}; +use std::ops::AddAssign; + +#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)] +pub struct CT(pub BInt); //ciphertext + +impl Display for CT { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "CT") + } +} + +impl Default for CT { + fn default() -> Self { + todo!() + } +} + +impl<'b> AddAssign<&'b CT> for CT { + fn add_assign(&mut self, _rhs: &'b CT) { + todo!() + } +} + +impl CT { + pub fn zero() -> CT { + CT(BInt::from(ONE)) + } + pub fn add_ct(&self, ct: &CT, pk: &PK) -> CT { + CT(&self.0 * &ct.0 % &pk.n) + } + pub fn i_double(&mut self, pk: &PK) { + self.0.pow_mod_mut(&BInt::from(2), &pk.n); + } + pub fn mul_pt(&self, b: &PT, pk: &PK) -> CT { + CT(self.0.pow_mod_ref(&b.0, &pk.n)) + } +} + +#[derive(Default, Clone, Deserialize, Serialize, Debug, PartialEq)] +pub struct PT(pub BInt); // plaintest + +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct PK { + pub n: BInt, + pub g: BInt, + pub h: BInt, +} + +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct SK { + pub p: BInt, + pub q: BInt, + pub g: BInt, + // pub n: BInt, + // // n = p * q + // p_minus_one: BInt, + // q_minus_one: BInt, + // ps: BInt, + // qs: BInt, + // p_invert: BInt, + // // p^{-1} mod q + // hp: BInt, + // hq: BInt, +} + +/// generate paillier keypairs with providing bit lenght +pub fn keygen(bit_lenght: u32) -> (SK, PK) { + let prime_bit_size = bit_lenght / 3; + let (mut p, mut q, mut n, mut g): (BInt, BInt, BInt, BInt); + loop { + p = BInt::gen_prime(prime_bit_size); + q = BInt::gen_prime(bit_lenght - 2 * prime_bit_size); + n = &p * &p * &q; + if p != q && n.significant_bits() == bit_lenght { + break; + } + } + let p2 = &p * &p; + let p_1 = &p - 1; + let n_1 = &n - 1; + loop { + g = BInt::gen_positive_integer(&n_1) + 1; + if g.pow_mod_ref(&p_1, &p2).ne(&BInt::from(1u8)) { + break; + } + } + let h = g.pow_mod_ref(&n, &n); + (SK::new(p, q, g.clone()), PK::new(n, g, h)) +} + +impl PK { + fn new(n: BInt, g: BInt, h: BInt) -> PK { + PK { n, g, h } + } + /// encrypt plaintext + /// + /// ```math + /// (plaintext \cdot n + 1)r^n \pmod{n^2} + /// ``` + pub fn encrypt(&self, plaintext: &PT, obfuscate: bool) -> CT { + let r = BInt::gen_positive_integer(&self.n); + let c = self.g.pow_mod_ref(&plaintext.0, &self.n) * self.h.pow_mod_ref(&r, &self.n); + CT(c) + } +} + +impl SK { + fn new(p: BInt, q: BInt, g: BInt) -> SK { + assert!(p != q, "p == q"); + SK { + p, + q, + g, + } + } + /// decrypt ciphertext + /// + /// crt optimization applied: + /// ```math + /// dp = \frac{(c^{p-1} \pmod{p^2})-1}{p}\cdot hp \pmod{p}\\ + /// ``` + /// ```math + /// dq = \frac{(c^{q-1} \pmod{q^2})-1}{q}\cdot hq \pmod{q}\\ + /// ``` + /// ```math + /// ((dq - dp)(p^{-1} \pmod{q}) \pmod{q})p + dp + /// ``` + pub fn decrypt(&self, c: &CT) -> PT { + let ps = &self.p * &self.p; + let p_1 = &self.p - 1; + let dp = SK::h_function(&c.0, &self.p, &p_1, &ps); + let dq = SK::h_function(&self.g, &self.p, &p_1, &ps); + let dp1 = dq.invert(&self.p); + let mut m = (dp * dp1) % &self.p; + // TODO: any better way to do this? + if m < BInt::from(0) { + m.0.add_assign(&self.p.0) + } + + PT(m) + } + #[inline] + fn h_function(c: &BInt, p: &BInt, p_1: &BInt, ps: &BInt) -> BInt { + let x = c.pow_mod_ref(p_1, ps) - ONE; + (x / p) % p + } +} + +#[test] +fn keygen_even_size() { + keygen(1024); +} + +#[test] +#[should_panic] +fn keygen_odd_size() { + keygen(1023); +} + +#[test] +fn test_decrypt() { + let (private, public) = keygen(1024); + let plaintext = PT(BInt::from(25519u32)); + let ciphertext = public.encrypt(&plaintext, true); + let decrypted = private.decrypt(&ciphertext); + assert_eq!(plaintext, decrypted) +} +#[test] +fn test_add() { + let (private, public) = keygen(1024); + let plaintext1 = PT(BInt::from(25519u32)); + let plaintext2 = PT(BInt::from(12345u32)); + let ciphertext1 = public.encrypt(&plaintext1, true); + let ciphertext2 = public.encrypt(&plaintext2, true); + let ciphertext3 = ciphertext1.add_ct(&ciphertext2, &public); + let decrypted = private.decrypt(&ciphertext3); + assert_eq!(PT(BInt::from(25519u32 + 12345u32)), decrypted) +}