Skip to content

Commit

Permalink
Merge pull request #5035 from FederatedAI/dev-2.0.0-beta-sage-paillier
Browse files Browse the repository at this point in the history
dev 2.0.0 beta sage paillier
  • Loading branch information
nemirorox authored Aug 9, 2023
2 parents bd03583 + 202c75e commit 713e08f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 4 deletions.
9 changes: 9 additions & 0 deletions python/fate/arch/protocol/phe/paillier.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def encode_tensor(self, tensor: V, dtype: torch.dtype = None) -> FV:
dtype = tensor.dtype
return self.encode_vec(tensor.flatten(), dtype=dtype)

def pack_vec(self, vec: torch.LongTensor, num_shift_bit, num_elem_each_pack) -> FV:
return self.coder.pack_u64_vec(vec.detach().tolist(), num_shift_bit, num_elem_each_pack)

def unpack_vec(self, vec: FV, num_shift_bit, num_elem_each_pack, total_num) -> torch.LongTensor:
return torch.LongTensor(self.coder.unpack_u64_vec(vec, num_shift_bit, num_elem_each_pack, total_num))

def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None, device=None) -> V:
data = self.decode_vec(tensor, dtype)
if shape is not None:
Expand All @@ -54,6 +60,9 @@ def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None
def encode_vec(self, vec: V, dtype: torch.dtype = None) -> FV:
if dtype is None:
dtype = vec.dtype
else:
if dtype != vec.dtype:
vec = vec.to(dtype=dtype)
if dtype == torch.float64:
return self.encode_f64_vec(vec)
if dtype == torch.float32:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use fixedpoint::CT;
use ndarray::prelude::*;
use pyo3::prelude::*;
use super::paillier;

Expand Down
19 changes: 17 additions & 2 deletions rust/fate_utils/crates/fate_utils/src/paillier/paillier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use ndarray::prelude::*;
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use rug::Integer;
use math::BInt;

#[pyclass(module = "fate_utils.paillier")]
#[derive(Default)]
Expand Down Expand Up @@ -144,6 +142,23 @@ impl Coders {
.collect();
FixedpointVector { data }
}
fn pack_u64_vec(&self, data: Vec<u64>, shift_bit: usize, num_each_pack: usize) -> FixedpointVector {
FixedpointVector {
data: data.chunks(num_each_pack).map(|x| {
self.coder.pack(x, shift_bit)
}).collect::<Vec<_>>()
}
}
fn unpack_u64_vec(&self, data: &FixedpointVector, shift_bit: usize, num_each_pack: usize, total_num: usize) -> Vec<u64> {
let mut result = Vec::with_capacity(total_num);
let mut total_num = total_num;
for x in data.data.iter() {
let n = std::cmp::min(total_num, num_each_pack);
result.extend(self.coder.unpack(x, shift_bit, n));
total_num -= n;
}
result
}
fn decode_f64(&self, data: &FixedpointEncoded) -> f64 {
self.coder.decode_f64(&data.data)
}
Expand Down
31 changes: 30 additions & 1 deletion rust/fate_utils/crates/fixedpoint/src/coder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::ops::{AddAssign, ShlAssign, SubAssign};
use super::frexp::Frexp;
use super::PT;
use crate::paillier;
use math::BInt;
use rug::{self, ops::Pow};
use rug::{self, Integer, ops::Pow};
use serde::{Deserialize, Serialize};

const FLOAT_MANTISSA_BITS: u32 = 53;
Expand Down Expand Up @@ -35,6 +36,31 @@ impl FixedpointCoder {
exp: 0,
}
}
pub fn pack(&self, plaintexts: &[u64], num_shift_bit: usize) -> PT {
let significant = plaintexts.iter().fold(Integer::default(), |mut x, v| {
x.shl_assign(num_shift_bit);
x.add_assign(v);
x
});
PT {
significant: paillier::PT(BInt(significant)),
exp: 0,
}
}
pub fn unpack(&self, encoded: &PT, num_shift_bit: usize, num: usize) -> Vec<u64> {
let mut significant = encoded.significant.0.0.clone();
let mut mask = Integer::from(1u64 << num_shift_bit);
mask.sub_assign(1);

let mut result = Vec::with_capacity(num);
for _ in 0..num {
let value = Integer::from(significant.clone() & mask.clone()).to_u64().unwrap();
result.push(value);
significant >>= num_shift_bit;
}
result.reverse();
result
}
pub fn encode_i32(&self, plaintext: i32) -> PT {
let significant = paillier::PT(if plaintext < 0 {
BInt::from(&self.n + plaintext)
Expand Down Expand Up @@ -115,6 +141,7 @@ pub trait CouldCode {
fn encode(&self, coder: &FixedpointCoder) -> PT;
fn decode(pt: &PT, coder: &FixedpointCoder) -> Self;
}

impl CouldCode for f64 {
fn encode(&self, coder: &FixedpointCoder) -> PT {
coder.encode_f64(*self)
Expand All @@ -132,6 +159,7 @@ impl CouldCode for i64 {
coder.decode_i64(pt)
}
}

impl CouldCode for i32 {
fn encode(&self, coder: &FixedpointCoder) -> PT {
coder.encode_i32(*self)
Expand All @@ -140,6 +168,7 @@ impl CouldCode for i32 {
coder.decode_i32(pt)
}
}

impl CouldCode for f32 {
fn encode(&self, coder: &FixedpointCoder) -> PT {
coder.encode_f32(*self)
Expand Down

0 comments on commit 713e08f

Please sign in to comment.