From 34193c1c16a074e28cbbcee3bac797637c7b22a6 Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 9 Aug 2023 17:22:27 +0800 Subject: [PATCH 1/2] add pack coder Signed-off-by: weiwee --- python/fate/arch/protocol/phe/paillier.py | 6 ++++ .../fate_utils/src/paillier/evaluator.rs | 1 - .../fate_utils/src/paillier/paillier.rs | 19 ++++++++++-- .../fate_utils/crates/fixedpoint/src/coder.rs | 31 ++++++++++++++++++- 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/python/fate/arch/protocol/phe/paillier.py b/python/fate/arch/protocol/phe/paillier.py index 1123a2d1fa..4f5e9af268 100644 --- a/python/fate/arch/protocol/phe/paillier.py +++ b/python/fate/arch/protocol/phe/paillier.py @@ -43,6 +43,12 @@ def encode_tensor(self, tensor: V, dtype: torch.dtype = None) -> FV: dtype = tensor.dtype return self.encode_vec(tensor.flatten(), dtype=dtype) + def pack_vec(self, vec: torch.LongTensor, num_shift_bit, num_elem_each_pack) -> FV: + return self.coder.pack_u64_vec(vec.detach().tolist(), num_shift_bit, num_elem_each_pack) + + def unpack_vec(self, vec: FV, num_shift_bit, num_elem_each_pack, total_num) -> torch.LongTensor: + return torch.LongTensor(self.coder.unpack_u64_vec(vec, num_shift_bit, num_elem_each_pack, total_num)) + def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None, device=None) -> V: data = self.decode_vec(tensor, dtype) if shape is not None: diff --git a/rust/fate_utils/crates/fate_utils/src/paillier/evaluator.rs b/rust/fate_utils/crates/fate_utils/src/paillier/evaluator.rs index 004af46631..26e42566a7 100644 --- a/rust/fate_utils/crates/fate_utils/src/paillier/evaluator.rs +++ b/rust/fate_utils/crates/fate_utils/src/paillier/evaluator.rs @@ -1,5 +1,4 @@ use fixedpoint::CT; -use ndarray::prelude::*; use pyo3::prelude::*; use super::paillier; diff --git a/rust/fate_utils/crates/fate_utils/src/paillier/paillier.rs b/rust/fate_utils/crates/fate_utils/src/paillier/paillier.rs index ca249d24d6..fee60d833c 100644 --- a/rust/fate_utils/crates/fate_utils/src/paillier/paillier.rs +++ b/rust/fate_utils/crates/fate_utils/src/paillier/paillier.rs @@ -3,8 +3,6 @@ use ndarray::prelude::*; use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1}; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -use rug::Integer; -use math::BInt; #[pyclass(module = "fate_utils.paillier")] #[derive(Default)] @@ -144,6 +142,23 @@ impl Coders { .collect(); FixedpointVector { data } } + fn pack_u64_vec(&self, data: Vec, shift_bit: usize, num_each_pack: usize) -> FixedpointVector { + FixedpointVector { + data: data.chunks(num_each_pack).map(|x| { + self.coder.pack(x, shift_bit) + }).collect::>() + } + } + fn unpack_u64_vec(&self, data: &FixedpointVector, shift_bit: usize, num_each_pack: usize, total_num: usize) -> Vec { + let mut result = Vec::with_capacity(total_num); + let mut total_num = total_num; + for x in data.data.iter() { + let n = std::cmp::min(total_num, num_each_pack); + result.extend(self.coder.unpack(x, shift_bit, n)); + total_num -= n; + } + result + } fn decode_f64(&self, data: &FixedpointEncoded) -> f64 { self.coder.decode_f64(&data.data) } diff --git a/rust/fate_utils/crates/fixedpoint/src/coder.rs b/rust/fate_utils/crates/fixedpoint/src/coder.rs index cf9334eed3..60cc565fba 100644 --- a/rust/fate_utils/crates/fixedpoint/src/coder.rs +++ b/rust/fate_utils/crates/fixedpoint/src/coder.rs @@ -1,8 +1,9 @@ +use std::ops::{AddAssign, ShlAssign, SubAssign}; use super::frexp::Frexp; use super::PT; use crate::paillier; use math::BInt; -use rug::{self, ops::Pow}; +use rug::{self, Integer, ops::Pow}; use serde::{Deserialize, Serialize}; const FLOAT_MANTISSA_BITS: u32 = 53; @@ -35,6 +36,31 @@ impl FixedpointCoder { exp: 0, } } + pub fn pack(&self, plaintexts: &[u64], num_shift_bit: usize) -> PT { + let significant = plaintexts.iter().fold(Integer::default(), |mut x, v| { + x.shl_assign(num_shift_bit); + x.add_assign(v); + x + }); + PT { + significant: paillier::PT(BInt(significant)), + exp: 0, + } + } + pub fn unpack(&self, encoded: &PT, num_shift_bit: usize, num: usize) -> Vec { + let mut significant = encoded.significant.0.0.clone(); + let mut mask = Integer::from(1u64 << num_shift_bit); + mask.sub_assign(1); + + let mut result = Vec::with_capacity(num); + for _ in 0..num { + let value = Integer::from(significant.clone() & mask.clone()).to_u64().unwrap(); + result.push(value); + significant >>= num_shift_bit; + } + result.reverse(); + result + } pub fn encode_i32(&self, plaintext: i32) -> PT { let significant = paillier::PT(if plaintext < 0 { BInt::from(&self.n + plaintext) @@ -115,6 +141,7 @@ pub trait CouldCode { fn encode(&self, coder: &FixedpointCoder) -> PT; fn decode(pt: &PT, coder: &FixedpointCoder) -> Self; } + impl CouldCode for f64 { fn encode(&self, coder: &FixedpointCoder) -> PT { coder.encode_f64(*self) @@ -132,6 +159,7 @@ impl CouldCode for i64 { coder.decode_i64(pt) } } + impl CouldCode for i32 { fn encode(&self, coder: &FixedpointCoder) -> PT { coder.encode_i32(*self) @@ -140,6 +168,7 @@ impl CouldCode for i32 { coder.decode_i32(pt) } } + impl CouldCode for f32 { fn encode(&self, coder: &FixedpointCoder) -> PT { coder.encode_f32(*self) From 202c75e98bdc6229f7b3178f5349dfc5ccf62ca7 Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 9 Aug 2023 17:36:04 +0800 Subject: [PATCH 2/2] fix encode Signed-off-by: weiwee --- python/fate/arch/protocol/phe/paillier.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/fate/arch/protocol/phe/paillier.py b/python/fate/arch/protocol/phe/paillier.py index 4f5e9af268..f1c5dfc3fe 100644 --- a/python/fate/arch/protocol/phe/paillier.py +++ b/python/fate/arch/protocol/phe/paillier.py @@ -60,6 +60,9 @@ def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None def encode_vec(self, vec: V, dtype: torch.dtype = None) -> FV: if dtype is None: dtype = vec.dtype + else: + if dtype != vec.dtype: + vec = vec.to(dtype=dtype) if dtype == torch.float64: return self.encode_f64_vec(vec) if dtype == torch.float32: