diff --git a/Cargo.lock b/Cargo.lock index bd665bc..6797592 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -213,7 +213,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn", + "syn 2.0.46", ] [[package]] @@ -235,7 +235,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1167c583765f273205c691a0a8bff23a925aa5bc66f2448031e4a6e518cd64d7" dependencies = [ "aligned-vec", - "pulp", + "pulp 0.11.11", ] [[package]] @@ -339,6 +339,17 @@ dependencies = [ "typenum", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "digest" version = "0.10.7" @@ -412,6 +423,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" +[[package]] +name = "fastdiv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff846c3bc86867340b078d8795cddacf45243bde6f2f45da05257ebe209cfa6" + [[package]] name = "fastrand" version = "2.0.1" @@ -451,7 +468,9 @@ version = "0.1.0-beta.8" dependencies = [ "concrete-ntt", "criterion", + "derivative", "ethnum", + "fastdiv", "fhe-traits", "fhe-util", "itertools 0.12.1", @@ -462,11 +481,13 @@ dependencies = [ "proptest", "prost", "prost-build", + "pulp 0.18.9", "rand", "rand_chacha", "sha2", "thiserror", "zeroize", + "zeroize_derive", ] [[package]] @@ -761,6 +782,7 @@ version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" dependencies = [ + "bytemuck", "num-traits", ] @@ -870,7 +892,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.46", ] [[package]] @@ -929,7 +951,7 @@ dependencies = [ "prost", "prost-types", "regex", - "syn", + "syn 2.0.46", "tempfile", "which", ] @@ -944,7 +966,7 @@ dependencies = [ "itertools 0.11.0", "proc-macro2", "quote", - "syn", + "syn 2.0.46", ] [[package]] @@ -965,6 +987,18 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "pulp" +version = "0.18.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03457ac216146f43f921500bac4e892d5cd32b0479b929cbfc90f95cd6c599c2" +dependencies = [ + "bytemuck", + "libm", + "num-complex", + "reborrow", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -1045,6 +1079,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + [[package]] name = "redox_syscall" version = "0.3.5" @@ -1146,7 +1186,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.46", ] [[package]] @@ -1189,6 +1229,17 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.46" @@ -1230,7 +1281,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.46", ] [[package]] @@ -1325,7 +1376,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.46", "wasm-bindgen-shared", ] @@ -1347,7 +1398,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.46", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1557,5 +1608,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.46", ] diff --git a/Cargo.toml b/Cargo.toml index 666a3c1..f8756ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ criterion = "^0.5.1" doc-comment = "^0.3.3" env_logger = "^0.11.3" ethnum = "^1.5.0" +fastdiv = "^0.1.0" indicatif = "^0.17.8" itertools = "^0.12.1" log = "^0.4.21" diff --git a/crates/fhe-math/Cargo.toml b/crates/fhe-math/Cargo.toml index 82eda35..b844dab 100644 --- a/crates/fhe-math/Cargo.toml +++ b/crates/fhe-math/Cargo.toml @@ -21,17 +21,21 @@ fhe-traits = { version = "^0.1.0-beta.8", path = "../fhe-traits" } fhe-util = { version = "^0.1.0-beta.8", path = "../fhe-util" } concrete-ntt.workspace = true +derivative = "^2.2.0" ethnum.workspace = true +fastdiv.workspace = true itertools.workspace = true ndarray.workspace = true num-bigint.workspace = true num-bigint-dig.workspace = true num-traits.workspace = true prost.workspace = true +pulp = "^0.18.9" rand.workspace = true rand_chacha.workspace = true thiserror.workspace = true zeroize.workspace = true +zeroize_derive.workspace = true sha2.workspace = true [build-dependencies] diff --git a/crates/fhe-math/src/rns/scaler.rs b/crates/fhe-math/src/rns/scaler.rs index 7b92fc4..b3051b1 100644 --- a/crates/fhe-math/src/rns/scaler.rs +++ b/crates/fhe-math/src/rns/scaler.rs @@ -315,13 +315,13 @@ impl RnsScaler { let gamma_i = self.gamma.get_unchecked(starting_index + i); let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i); - let mut yi = (qi.modulus() * 2 + let mut yi = (**qi * 2 - qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i)) as u128; if !self.scaling_factor.is_one { let wi = qi.lazy_reduce_u128(w); - yi += if w_sign { qi.modulus() * 2 - wi } else { wi } as u128; + yi += if w_sign { **qi * 2 - wi } else { wi } as u128; } debug_assert!(rests.len() <= omega_i.len()); diff --git a/crates/fhe-math/src/rq/mod.rs b/crates/fhe-math/src/rq/mod.rs index f782d51..14fa381 100644 --- a/crates/fhe-math/src/rq/mod.rs +++ b/crates/fhe-math/src/rq/mod.rs @@ -11,17 +11,16 @@ mod serialize; pub mod scaler; pub mod switcher; pub mod traits; -pub use context::Context; -pub use ops::dot_product; -use sha2::{Digest, Sha256}; - use self::{scaler::Scaler, switcher::Switcher, traits::TryConvertFrom}; use crate::{Error, Result}; +pub use context::Context; use fhe_util::sample_vec_cbd; use itertools::{izip, Itertools}; use ndarray::{s, Array2, ArrayView2, Axis}; +pub use ops::dot_product; use rand::{CryptoRng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; +use sha2::{Digest, Sha256}; use std::sync::Arc; use zeroize::{Zeroize, Zeroizing}; @@ -87,6 +86,16 @@ pub struct Poly { coefficients_shoup: Option>, } +// Implements zeroization of polynomials +impl Zeroize for Poly { + fn zeroize(&mut self) { + if let Some(coeffs) = self.coefficients.as_slice_mut() { + coeffs.zeroize() + } + self.zeroize_shoup() + } +} + impl AsRef for Poly { fn as_ref(&self) -> &Poly { self @@ -136,6 +145,17 @@ impl Poly { &self.representation } + /// Zeroize the shoup coefficients + fn zeroize_shoup(&mut self) { + if let Some(coeffs_shoup) = self + .coefficients_shoup + .as_mut() + .and_then(|f| f.as_slice_mut()) + { + coeffs_shoup.zeroize() + } + } + /// Change the representation of the underlying polynomial. pub fn change_representation(&mut self, to: Representation) { match self.representation { @@ -160,12 +180,7 @@ impl Poly { if to != Representation::NttShoup { // We are not sure whether this polynomial was sensitive or not, // so for security, we zeroize the Shoup coefficients. - self.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); + self.zeroize_shoup(); self.coefficients_shoup = None } match to { @@ -203,19 +218,15 @@ impl Poly { /// Prefer the `change_representation` function to safely modify the /// polynomial representation. If the `to` representation is NttShoup, the /// coefficients are still computed correctly to avoid being in an unstable - /// state. Similarly, if we override a representation which was NttShoup, we - /// zeroize the existing Shoup coefficients. + /// state. If we override a polynomial with Shoup coefficients, we zeroize + /// them. pub unsafe fn override_representation(&mut self, to: Representation) { + if self.coefficients_shoup.is_some() { + self.zeroize_shoup(); + self.coefficients_shoup = None + } if to == Representation::NttShoup { self.compute_coefficients_shoup() - } else if self.coefficients_shoup.is_some() { - self.coefficients_shoup - .as_mut() - .unwrap() - .as_slice_mut() - .unwrap() - .zeroize(); - self.coefficients_shoup = None } self.representation = to; } @@ -438,7 +449,7 @@ impl Poly { let q_len = self.ctx.q.len(); let q_last = self.ctx.q.last().unwrap(); - let q_last_div_2 = q_last.modulus() / 2; + let q_last_div_2 = (**q_last) / 2; // Add (q_last - 1) / 2 to change from flooring to rounding let (mut q_new_polys, mut q_last_poly) = @@ -456,14 +467,14 @@ impl Poly { self.ctx.inv_last_qi_mod_qj_shoup.iter(), ) .for_each(|(coeffs, qi, inv, inv_shoup)| { - let q_last_div_2_mod_qi = qi.modulus() - qi.reduce_vt(q_last_div_2); // Up to qi.modulus() + let q_last_div_2_mod_qi = **qi - qi.reduce_vt(q_last_div_2); // Up to qi.modulus() for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) { // (x mod q_last - q_L/2) mod q_i let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus() // ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i // = (x - x mod q_last + q_L/2) mod q_i - *coeff += 3 * qi.modulus() - tmp; // Up to 4 * qi.modulus() + *coeff += 3 * (**qi) - tmp; // Up to 4 * qi.modulus() // q_last^{-1} * (x - x mod q_last) mod q_i *coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup); @@ -481,14 +492,14 @@ impl Poly { self.ctx.inv_last_qi_mod_qj_shoup.iter(), ) .for_each(|(coeffs, qi, inv, inv_shoup)| { - let q_last_div_2_mod_qi = qi.modulus() - qi.reduce(q_last_div_2); // Up to qi.modulus() + let q_last_div_2_mod_qi = **qi - qi.reduce(q_last_div_2); // Up to qi.modulus() for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) { // (x mod q_last - q_L/2) mod q_i let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus() // ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i // = (x - x mod q_last + q_L/2) mod q_i - *coeff += 3 * qi.modulus() - tmp; // Up to 4 * qi.modulus() + *coeff += 3 * (**qi) - tmp; // Up to 4 * qi.modulus() // q_last^{-1} * (x - x mod q_last) mod q_i *coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup); @@ -567,15 +578,6 @@ impl Poly { } } -impl Zeroize for Poly { - fn zeroize(&mut self) { - self.coefficients.as_slice_mut().unwrap().zeroize(); - if let Some(s) = self.coefficients_shoup.as_mut() { - s.as_slice_mut().unwrap().zeroize(); - } - } -} - #[cfg(test)] mod tests { use super::{switcher::Switcher, Context, Poly, Representation}; @@ -1024,7 +1026,7 @@ mod tests { ) .collect_vec() ); - reference = p_biguint.clone(); + reference.clone_from(&p_biguint); } } Ok(()) diff --git a/crates/fhe-math/src/rq/ops.rs b/crates/fhe-math/src/rq/ops.rs index 61b4a29..d2462bd 100644 --- a/crates/fhe-math/src/rq/ops.rs +++ b/crates/fhe-math/src/rq/ops.rs @@ -380,7 +380,7 @@ where .ctx .q .iter() - .map(|qi| 1u128 << (2 * qi.modulus().leading_zeros())) + .map(|qi| 1u128 << (2 * (*qi).leading_zeros())) .collect_vec(); let max_acc_ptr = max_acc.as_ptr(); diff --git a/crates/fhe-math/src/rq/scaler.rs b/crates/fhe-math/src/rq/scaler.rs index e03c3d4..b1c8a3d 100644 --- a/crates/fhe-math/src/rq/scaler.rs +++ b/crates/fhe-math/src/rq/scaler.rs @@ -142,7 +142,7 @@ mod tests { use num_bigint::BigUint; use num_traits::{One, Zero}; use rand::thread_rng; - use std::{error::Error, sync::Arc}; + use std::error::Error; // Moduli to be used in tests. static Q: &[u64; 3] = &[ @@ -161,8 +161,8 @@ mod tests { fn scaler() -> Result<(), Box> { let mut rng = thread_rng(); let ntests = 100; - let from = Arc::new(Context::new(Q, 16)?); - let to = Arc::new(Context::new(P, 16)?); + let from = Context::new_arc(Q, 16)?; + let to = Context::new_arc(P, 16)?; for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] { for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] { diff --git a/crates/fhe-math/src/zq/mod.rs b/crates/fhe-math/src/zq/mod.rs index f6b9ac0..07e438a 100644 --- a/crates/fhe-math/src/zq/mod.rs +++ b/crates/fhe-math/src/zq/mod.rs @@ -4,11 +4,15 @@ pub mod primes; +use std::ops::Deref; + use crate::errors::{Error, Result}; +use derivative::Derivative; use fhe_util::{is_prime, transcode_from_bytes, transcode_to_bytes}; use itertools::{izip, Itertools}; use num_bigint::BigUint; use num_traits::cast::ToPrimitive; +use pulp::Arch; use rand::{distributions::Uniform, CryptoRng, Rng, RngCore}; /// cond ? on_true : on_false @@ -19,7 +23,9 @@ const fn const_time_cond_select(on_true: u64, on_false: u64, cond: bool) -> u64 } /// Structure encapsulating an integer modulus up to 62 bits. -#[derive(Debug, Clone, PartialEq)] +#[derive(Derivative)] +#[derivative(PartialEq)] +#[derive(Debug, Clone)] pub struct Modulus { pub(crate) p: u64, nbits: usize, @@ -28,11 +34,22 @@ pub struct Modulus { leading_zeros: u32, pub(crate) supports_opt: bool, distribution: Uniform, + #[derivative(PartialEq = "ignore")] + arch: Arch, } // We need to declare Eq manually because of the `Uniform` member. impl Eq for Modulus {} +// Override the dereference to return the underlying modulus. +impl Deref for Modulus { + type Target = u64; + + fn deref(&self) -> &Self::Target { + &self.p + } +} + impl Modulus { /// Create a modulus from an integer of at most 62 bits. pub fn new(p: u64) -> Result { @@ -48,15 +65,11 @@ impl Modulus { leading_zeros: p.leading_zeros(), supports_opt: primes::supports_opt(p), distribution: Uniform::from(0..p), + arch: Arch::new(), }) } } - /// Returns the value of the modulus. - pub const fn modulus(&self) -> u64 { - self.p - } - /// Performs the modular addition of a and b in constant time. /// Aborts if a >= p or b >= p in debug mode. pub const fn add(&self, a: u64, b: u64) -> u64 { @@ -201,8 +214,9 @@ impl Modulus { /// debug mode. pub fn add_vec(&self, a: &mut [u64], b: &[u64]) { debug_assert_eq!(a.len(), b.len()); - - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add(*ai, *bi)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add(*ai, *bi)) + }) } /// Modular addition of vectors in place in variable time. @@ -225,26 +239,30 @@ impl Modulus { } if n % 16 == 0 { - for i in 0..n / 16 { - add_at!(16 * i); - add_at!(16 * i + 1); - add_at!(16 * i + 2); - add_at!(16 * i + 3); - add_at!(16 * i + 4); - add_at!(16 * i + 5); - add_at!(16 * i + 6); - add_at!(16 * i + 7); - add_at!(16 * i + 8); - add_at!(16 * i + 9); - add_at!(16 * i + 10); - add_at!(16 * i + 11); - add_at!(16 * i + 12); - add_at!(16 * i + 13); - add_at!(16 * i + 14); - add_at!(16 * i + 15); - } + self.arch.dispatch(|| { + for i in 0..n / 16 { + add_at!(16 * i); + add_at!(16 * i + 1); + add_at!(16 * i + 2); + add_at!(16 * i + 3); + add_at!(16 * i + 4); + add_at!(16 * i + 5); + add_at!(16 * i + 6); + add_at!(16 * i + 7); + add_at!(16 * i + 8); + add_at!(16 * i + 9); + add_at!(16 * i + 10); + add_at!(16 * i + 11); + add_at!(16 * i + 12); + add_at!(16 * i + 13); + add_at!(16 * i + 14); + add_at!(16 * i + 15); + } + }) } else { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add_vt(*ai, *bi)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add_vt(*ai, *bi)) + }) } } @@ -254,8 +272,9 @@ impl Modulus { /// debug mode. pub fn sub_vec(&self, a: &mut [u64], b: &[u64]) { debug_assert_eq!(a.len(), b.len()); - - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub(*ai, *bi)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub(*ai, *bi)) + }) } /// Modular subtraction of vectors in place in variable time. @@ -278,26 +297,30 @@ impl Modulus { } if n % 16 == 0 { - for i in 0..n / 16 { - sub_at!(16 * i); - sub_at!(16 * i + 1); - sub_at!(16 * i + 2); - sub_at!(16 * i + 3); - sub_at!(16 * i + 4); - sub_at!(16 * i + 5); - sub_at!(16 * i + 6); - sub_at!(16 * i + 7); - sub_at!(16 * i + 8); - sub_at!(16 * i + 9); - sub_at!(16 * i + 10); - sub_at!(16 * i + 11); - sub_at!(16 * i + 12); - sub_at!(16 * i + 13); - sub_at!(16 * i + 14); - sub_at!(16 * i + 15); - } + self.arch.dispatch(|| { + for i in 0..n / 16 { + sub_at!(16 * i); + sub_at!(16 * i + 1); + sub_at!(16 * i + 2); + sub_at!(16 * i + 3); + sub_at!(16 * i + 4); + sub_at!(16 * i + 5); + sub_at!(16 * i + 6); + sub_at!(16 * i + 7); + sub_at!(16 * i + 8); + sub_at!(16 * i + 9); + sub_at!(16 * i + 10); + sub_at!(16 * i + 11); + sub_at!(16 * i + 12); + sub_at!(16 * i + 13); + sub_at!(16 * i + 14); + sub_at!(16 * i + 15); + } + }) } else { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub_vt(*ai, *bi)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub_vt(*ai, *bi)) + }) } } @@ -309,9 +332,13 @@ impl Modulus { debug_assert_eq!(a.len(), b.len()); if self.supports_opt { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt(*ai, *bi)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt(*ai, *bi)) + }) } else { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul(*ai, *bi)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul(*ai, *bi)) + }) } } @@ -320,8 +347,10 @@ impl Modulus { /// Aborts if any of the values in a is >= p in debug mode. pub fn scalar_mul_vec(&self, a: &mut [u64], b: u64) { let b_shoup = self.shoup(b); - a.iter_mut() - .for_each(|ai| *ai = self.mul_shoup(*ai, b, b_shoup)); + self.arch.dispatch(|| { + a.iter_mut() + .for_each(|ai| *ai = self.mul_shoup(*ai, b, b_shoup)) + }) } /// Modular scalar multiplication of vectors in place in variable time. @@ -332,8 +361,10 @@ impl Modulus { /// about the values being multiplied. pub unsafe fn scalar_mul_vec_vt(&self, a: &mut [u64], b: u64) { let b_shoup = self.shoup(b); - a.iter_mut() - .for_each(|ai| *ai = self.mul_shoup_vt(*ai, b, b_shoup)); + self.arch.dispatch(|| { + a.iter_mut() + .for_each(|ai| *ai = self.mul_shoup_vt(*ai, b, b_shoup)) + }) } /// Modular multiplication of vectors in place in variable time. @@ -347,9 +378,13 @@ impl Modulus { debug_assert_eq!(a.len(), b.len()); if self.supports_opt { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt_vt(*ai, *bi)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt_vt(*ai, *bi)) + }) } else { - izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_vt(*ai, *bi)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_vt(*ai, *bi)) + }) } } @@ -357,7 +392,8 @@ impl Modulus { /// /// Aborts if any of the values of the vector is >= p in debug mode. pub fn shoup_vec(&self, a: &[u64]) -> Vec { - a.iter().map(|ai| self.shoup(*ai)).collect_vec() + self.arch + .dispatch(|| a.iter().map(|ai| self.shoup(*ai)).collect_vec()) } /// Shoup modular multiplication of vectors in place in constant time. @@ -369,8 +405,10 @@ impl Modulus { debug_assert_eq!(a.len(), b_shoup.len()); debug_assert_eq!(&b_shoup, &self.shoup_vec(b)); - izip!(a.iter_mut(), b.iter(), b_shoup.iter()) - .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup(*ai, *bi, *bi_shoup)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter(), b_shoup.iter()) + .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup(*ai, *bi, *bi_shoup)) + }) } /// Shoup modular multiplication of vectors in place in variable time. @@ -385,13 +423,16 @@ impl Modulus { debug_assert_eq!(a.len(), b_shoup.len()); debug_assert_eq!(&b_shoup, &self.shoup_vec(b)); - izip!(a.iter_mut(), b.iter(), b_shoup.iter()) - .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup_vt(*ai, *bi, *bi_shoup)); + self.arch.dispatch(|| { + izip!(a.iter_mut(), b.iter(), b_shoup.iter()) + .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup_vt(*ai, *bi, *bi_shoup)) + }) } /// Reduce a vector in place in constant time. pub fn reduce_vec(&self, a: &mut [u64]) { - a.iter_mut().for_each(|ai| *ai = self.reduce(*ai)); + self.arch + .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.reduce(*ai))) } /// Center a value modulo p as i64 in variable time. @@ -416,7 +457,8 @@ impl Modulus { /// This function is not constant time and its timing may reveal information /// about the values being centered. pub unsafe fn center_vec_vt(&self, a: &[u64]) -> Vec { - a.iter().map(|ai| self.center_vt(*ai)).collect_vec() + self.arch + .dispatch(|| a.iter().map(|ai| self.center_vt(*ai)).collect_vec()) } /// Reduce a vector in place in variable time. @@ -425,7 +467,8 @@ impl Modulus { /// This function is not constant time and its timing may reveal information /// about the values being reduced. pub unsafe fn reduce_vec_vt(&self, a: &mut [u64]) { - a.iter_mut().for_each(|ai| *ai = self.reduce_vt(*ai)); + self.arch + .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.reduce_vt(*ai))) } /// Modular reduction of a i64 in constant time. @@ -444,7 +487,8 @@ impl Modulus { /// Reduce a vector in place in constant time. pub fn reduce_vec_i64(&self, a: &[i64]) -> Vec { - a.iter().map(|ai| self.reduce_i64(*ai)).collect_vec() + self.arch + .dispatch(|| a.iter().map(|ai| self.reduce_i64(*ai)).collect_vec()) } /// Reduce a vector in place in variable time. @@ -453,12 +497,14 @@ impl Modulus { /// This function is not constant time and its timing may reveal information /// about the values being reduced. pub unsafe fn reduce_vec_i64_vt(&self, a: &[i64]) -> Vec { - a.iter().map(|ai| self.reduce_i64_vt(*ai)).collect_vec() + self.arch + .dispatch(|| a.iter().map(|ai| self.reduce_i64_vt(*ai)).collect()) } /// Reduce a vector in constant time. pub fn reduce_vec_new(&self, a: &[u64]) -> Vec { - a.iter().map(|ai| self.reduce(*ai)).collect_vec() + self.arch + .dispatch(|| a.iter().map(|ai| self.reduce(*ai)).collect()) } /// Reduce a vector in variable time. @@ -467,14 +513,16 @@ impl Modulus { /// This function is not constant time and its timing may reveal information /// about the values being reduced. pub unsafe fn reduce_vec_new_vt(&self, a: &[u64]) -> Vec { - a.iter().map(|bi| self.reduce_vt(*bi)).collect_vec() + self.arch + .dispatch(|| a.iter().map(|bi| self.reduce_vt(*bi)).collect()) } /// Modular negation of a vector in place in constant time. /// /// Aborts if any of the values in the vector is >= p in debug mode. pub fn neg_vec(&self, a: &mut [u64]) { - izip!(a.iter_mut()).for_each(|ai| *ai = self.neg(*ai)); + self.arch + .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.neg(*ai))) } /// Modular negation of a vector in place in variable time. @@ -484,7 +532,8 @@ impl Modulus { /// This function is not constant time and its timing may reveal information /// about the values being negated. pub unsafe fn neg_vec_vt(&self, a: &mut [u64]) { - izip!(a.iter_mut()).for_each(|ai| *ai = self.neg_vt(*ai)); + self.arch + .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.neg_vt(*ai))) } /// Modular exponentiation in variable time. @@ -758,19 +807,19 @@ mod tests { prop_assume!(p >> 2 >= 2); let q = Modulus::new(p >> 2); prop_assert!(q.is_ok()); - prop_assert_eq!(q.unwrap().modulus(), p >> 2); + prop_assert_eq!(*q.unwrap(), p >> 2); } #[test] fn neg(p in valid_moduli(), mut a: u64) { a = p.reduce(a); - prop_assert_eq!(p.neg(a), (p.modulus() - a) % p.modulus()); - unsafe { prop_assert_eq!(p.neg_vt(a), (p.modulus() - a) % p.modulus()) } + prop_assert_eq!(p.neg(a), (*p - a) % *p); + unsafe { prop_assert_eq!(p.neg_vt(a), (*p - a) % *p) } #[cfg(debug_assertions)] { - prop_assert!(std::panic::catch_unwind(|| p.neg(p.modulus())).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.neg(p.modulus() + 1)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.neg(*p)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.neg(*p + 1)).is_err()); } } @@ -778,15 +827,15 @@ mod tests { fn add(p in valid_moduli(), mut a: u64, mut b: u64) { a = p.reduce(a); b = p.reduce(b); - prop_assert_eq!(p.add(a, b), (a + b) % p.modulus()); - unsafe { prop_assert_eq!(p.add_vt(a, b), (a + b) % p.modulus()) } + prop_assert_eq!(p.add(a, b), (a + b) % *p); + unsafe { prop_assert_eq!(p.add_vt(a, b), (a + b) % *p) } #[cfg(debug_assertions)] { - prop_assert!(std::panic::catch_unwind(|| p.add(p.modulus(), a)).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.add(a, p.modulus())).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.add(p.modulus() + 1, a)).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.add(a, p.modulus() + 1)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.add(*p, a)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.add(a, *p)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.add(*p + 1, a)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.add(a, *p + 1)).is_err()); } } @@ -794,15 +843,15 @@ mod tests { fn sub(p in valid_moduli(), mut a: u64, mut b: u64) { a = p.reduce(a); b = p.reduce(b); - prop_assert_eq!(p.sub(a, b), (a + p.modulus() - b) % p.modulus()); - unsafe { prop_assert_eq!(p.sub_vt(a, b), (a + p.modulus() - b) % p.modulus()) } + prop_assert_eq!(p.sub(a, b), (a + *p - b) % *p); + unsafe { prop_assert_eq!(p.sub_vt(a, b), (a + *p - b) % *p) } #[cfg(debug_assertions)] { - prop_assert!(std::panic::catch_unwind(|| p.sub(p.modulus(), a)).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.sub(a, p.modulus())).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.sub(p.modulus() + 1, a)).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.sub(a, p.modulus() + 1)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.sub(*p, a)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.sub(a, *p)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.sub(*p + 1, a)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.sub(a, *p + 1)).is_err()); } } @@ -810,15 +859,15 @@ mod tests { fn mul(p in valid_moduli(), mut a: u64, mut b: u64) { a = p.reduce(a); b = p.reduce(b); - prop_assert_eq!(p.mul(a, b) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)); - unsafe { prop_assert_eq!(p.mul_vt(a, b) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)) } + prop_assert_eq!(p.mul(a, b) as u128, ((a as u128) * (b as u128)) % (*p as u128)); + unsafe { prop_assert_eq!(p.mul_vt(a, b) as u128, ((a as u128) * (b as u128)) % (*p as u128)) } #[cfg(debug_assertions)] { - prop_assert!(std::panic::catch_unwind(|| p.mul(p.modulus(), a)).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.mul(a, p.modulus())).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.mul(p.modulus() + 1, a)).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.mul(a, p.modulus() + 1)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.mul(*p, a)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.mul(a, *p)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.mul(*p + 1, a)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.mul(a, *p + 1)).is_err()); } } @@ -832,18 +881,18 @@ mod tests { #[cfg(debug_assertions)] { - prop_assert!(std::panic::catch_unwind(|| p.shoup(p.modulus())).is_err()); - prop_assert!(std::panic::catch_unwind(|| p.shoup(p.modulus() + 1)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.shoup(*p)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.shoup(*p + 1)).is_err()); } // Check that the multiplication yields the expected result - prop_assert_eq!(p.mul_shoup(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)); - unsafe { prop_assert_eq!(p.mul_shoup_vt(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (p.modulus() as u128)) } + prop_assert_eq!(p.mul_shoup(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (*p as u128)); + unsafe { prop_assert_eq!(p.mul_shoup_vt(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (*p as u128)) } // Check that the multiplication with incorrect b_shoup panics in debug mode #[cfg(debug_assertions)] { - prop_assert!(std::panic::catch_unwind(|| p.mul_shoup(a, p.modulus(), b_shoup)).is_err()); + prop_assert!(std::panic::catch_unwind(|| p.mul_shoup(a, *p, b_shoup)).is_err()); prop_assume!(a != b); prop_assert!(std::panic::catch_unwind(|| p.mul_shoup(a, a, b_shoup)).is_err()); } @@ -851,18 +900,18 @@ mod tests { #[test] fn reduce(p in valid_moduli(), a: u64) { - prop_assert_eq!(p.reduce(a), a % p.modulus()); - unsafe { prop_assert_eq!(p.reduce_vt(a), a % p.modulus()) } + prop_assert_eq!(p.reduce(a), a % *p); + unsafe { prop_assert_eq!(p.reduce_vt(a), a % *p) } if p.supports_opt { - prop_assert_eq!(p.reduce_opt(a), a % p.modulus()); - unsafe { prop_assert_eq!(p.reduce_opt_vt(a), a % p.modulus()) } + prop_assert_eq!(p.reduce_opt(a), a % *p); + unsafe { prop_assert_eq!(p.reduce_opt_vt(a), a % *p) } } } #[test] fn lazy_reduce(p in valid_moduli(), a: u64) { - prop_assert!(p.lazy_reduce(a) < 2 * p.modulus()); - prop_assert_eq!(p.lazy_reduce(a) % p.modulus(), p.reduce(a)); + prop_assert!(p.lazy_reduce(a) < 2 * *p); + prop_assert_eq!(p.lazy_reduce(a) % *p, p.reduce(a)); } #[test] @@ -874,13 +923,13 @@ mod tests { #[test] fn reduce_u128(p in valid_moduli(), mut a: u128) { - prop_assert_eq!(p.reduce_u128(a) as u128, a % (p.modulus() as u128)); - unsafe { prop_assert_eq!(p.reduce_u128_vt(a) as u128, a % (p.modulus() as u128)) } + prop_assert_eq!(p.reduce_u128(a) as u128, a % (*p as u128)); + unsafe { prop_assert_eq!(p.reduce_u128_vt(a) as u128, a % (*p as u128)) } if p.supports_opt { - let p_square = (p.modulus() as u128) * (p.modulus() as u128); + let p_square = (*p as u128) * (*p as u128); a %= p_square; - prop_assert_eq!(p.reduce_opt_u128(a) as u128, a % (p.modulus() as u128)); - unsafe { prop_assert_eq!(p.reduce_opt_u128_vt(a) as u128, a % (p.modulus() as u128)) } + prop_assert_eq!(p.reduce_opt_u128(a) as u128, a % (*p as u128)); + unsafe { prop_assert_eq!(p.reduce_opt_u128_vt(a) as u128, a % (*p as u128)) } } } @@ -890,8 +939,8 @@ mod tests { p.reduce_vec(&mut b); let c = a.clone(); p.add_vec(&mut a, &b); - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec()); - a = c.clone(); + prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec()); + a.clone_from(&c); unsafe { p.add_vec_vt(&mut a, &b) } prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec()); } @@ -902,8 +951,8 @@ mod tests { p.reduce_vec(&mut b); let c = a.clone(); p.sub_vec(&mut a, &b); - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec()); - a = c.clone(); + prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec()); + a.clone_from(&c); unsafe { p.sub_vec_vt(&mut a, &b) } prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec()); } @@ -914,8 +963,8 @@ mod tests { p.reduce_vec(&mut b); let c = a.clone(); p.mul_vec(&mut a, &b); - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); - a = c.clone(); + prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); + a.clone_from(&c); unsafe { p.mul_vec_vt(&mut a, &b); } prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); } @@ -927,9 +976,9 @@ mod tests { let c = a.clone(); p.scalar_mul_vec(&mut a, b); - prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec()); + prop_assert_eq!(a.clone(), c.iter().map(|ci| p.mul(*ci, b)).collect_vec()); - a = c.clone(); + a.clone_from(&c); unsafe { p.scalar_mul_vec_vt(&mut a, b) } prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec()); } @@ -941,8 +990,8 @@ mod tests { let b_shoup = p.shoup_vec(&b); let c = a.clone(); p.mul_shoup_vec(&mut a, &b, &b_shoup); - prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); - a = c.clone(); + prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); + a.clone_from(&c); unsafe { p.mul_shoup_vec_vt(&mut a, &b, &b_shoup) } prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec()); } @@ -951,9 +1000,9 @@ mod tests { fn reduce_vec(p in valid_moduli(), a: Vec) { let mut b = a.clone(); p.reduce_vec(&mut b); - prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec()); + prop_assert_eq!(b.clone(), a.iter().map(|ai| p.reduce(*ai)).collect_vec()); - b = a.clone(); + b.clone_from(&a); unsafe { p.reduce_vec_vt(&mut b) } prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec()); } @@ -962,8 +1011,8 @@ mod tests { fn lazy_reduce_vec(p in valid_moduli(), a: Vec) { let mut b = a.clone(); p.lazy_reduce_vec(&mut b); - prop_assert!(b.iter().all(|bi| *bi < 2 * p.modulus())); - prop_assert!(izip!(a, b).all(|(ai, bi)| bi % p.modulus() == ai % p.modulus())); + prop_assert!(b.iter().all(|bi| *bi < 2 * *p)); + prop_assert!(izip!(a, b).all(|(ai, bi)| bi % *p == ai % *p)); } #[test] @@ -986,8 +1035,8 @@ mod tests { p.reduce_vec(&mut a); let mut b = a.clone(); p.neg_vec(&mut b); - prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec()); - b = a.clone(); + prop_assert_eq!(b.clone(), a.iter().map(|ai| p.neg(*ai)).collect_vec()); + b.clone_from(&a); unsafe { p.neg_vec_vt(&mut b); } prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec()); } @@ -1002,7 +1051,7 @@ mod tests { let w = p.random_vec(size, &mut rng); prop_assert_eq!(w.len(), size); - if p.modulus().leading_zeros() <= 30 { + if (*p).leading_zeros() <= 30 { prop_assert_ne!(v, w); // This will hold with probability at least 2^(-30) } } diff --git a/crates/fhe/examples/sealpir.rs b/crates/fhe/examples/sealpir.rs index d1a824a..46a5b8a 100644 --- a/crates/fhe/examples/sealpir.rs +++ b/crates/fhe/examples/sealpir.rs @@ -171,7 +171,7 @@ fn main() -> Result<(), Box> { .div_ceil(plaintext_modulus.ilog2() as usize), ); pt_values.append(&mut transcode_bidirectional( - c.get(0).unwrap().coefficients().as_slice().unwrap(), + c.first().unwrap().coefficients().as_slice().unwrap(), 64 - params.moduli()[0].leading_zeros() as usize, plaintext_modulus.ilog2() as usize, )); @@ -181,20 +181,19 @@ fn main() -> Result<(), Box> { plaintext_modulus.ilog2() as usize, )); unsafe { - Ok(bfv::PlaintextVec::try_encode_vt( + bfv::PlaintextVec::try_encode_vt( &pt_values, bfv::Encoding::poly_at_level(1), ¶ms, - )? - .0) + ) } }) - .collect::>>>()?; + .collect::>>()?; (0..fold[0].len()) .map(|i| { let mut outi = bfv::dot_product_scalar( expanded_query[dim1..].iter(), - fold.iter().map(|pts| pts.get(i).unwrap()), + fold.iter().map(|pts| &pts[i]), )?; outi.mod_switch_to_last_level()?; Ok(outi.to_bytes()) diff --git a/crates/fhe/src/bfv/ciphertext.rs b/crates/fhe/src/bfv/ciphertext.rs index 995f0c5..354b11b 100644 --- a/crates/fhe/src/bfv/ciphertext.rs +++ b/crates/fhe/src/bfv/ciphertext.rs @@ -10,6 +10,7 @@ use fhe_traits::{ use prost::Message; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; /// A ciphertext encrypting a plaintext. @@ -28,6 +29,20 @@ pub struct Ciphertext { pub(crate) level: usize, } +impl Deref for Ciphertext { + type Target = [Poly]; + + fn deref(&self) -> &Self::Target { + &self.c + } +} + +impl DerefMut for Ciphertext { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.c + } +} + impl Ciphertext { /// Modulo switch the ciphertext to the last level. pub fn mod_switch_to_last_level(&mut self) -> Result<()> { @@ -44,6 +59,11 @@ impl Ciphertext { Ok(()) } + /// Truncate the underlying vector of polynomials. + pub(crate) fn truncate(&mut self, len: usize) { + self.c.truncate(len) + } + /// Modulo switch the ciphertext to the next level. pub fn mod_switch_to_next_level(&mut self) -> Result<()> { if self.level < self.par.max_level() { @@ -89,11 +109,6 @@ impl Ciphertext { level, }) } - - /// Get the i-th polynomial of the ciphertext. - pub fn get(&self, i: usize) -> Option<&Poly> { - self.c.get(i) - } } impl FheCiphertext for Ciphertext {} @@ -136,13 +151,13 @@ impl Ciphertext { impl From<&Ciphertext> for CiphertextProto { fn from(ct: &Ciphertext) -> Self { let mut proto = CiphertextProto::default(); - for i in 0..ct.c.len() - 1 { - proto.c.push(ct.c[i].to_bytes()) + for i in 0..ct.len() - 1 { + proto.c.push(ct[i].to_bytes()) } if let Some(seed) = ct.seed { proto.seed = seed.to_vec() } else { - proto.c.push(ct.c[ct.c.len() - 1].to_bytes()) + proto.c.push(ct[ct.len() - 1].to_bytes()) } proto.level = ct.level as u32; proto @@ -252,9 +267,9 @@ mod tests { let ct: Ciphertext = sk.try_encrypt(&pt, &mut rng)?; let mut ct3 = &ct * &ct; - let c0 = ct3.get(0).unwrap(); - let c1 = ct3.get(1).unwrap(); - let c2 = ct3.get(2).unwrap(); + let c0 = &ct3[0]; + let c1 = &ct3[1]; + let c2 = &ct3[2]; assert_eq!( ct3, @@ -264,7 +279,7 @@ mod tests { ct3.mod_switch_to_last_level()?; - let c0 = ct3.get(0).unwrap(); + let c0 = ct3.first().unwrap(); let c1 = ct3.get(1).unwrap(); let c2 = ct3.get(2).unwrap(); assert_eq!( diff --git a/crates/fhe/src/bfv/keys/evaluation_key.rs b/crates/fhe/src/bfv/keys/evaluation_key.rs index ab379e2..2426dd2 100644 --- a/crates/fhe/src/bfv/keys/evaluation_key.rs +++ b/crates/fhe/src/bfv/keys/evaluation_key.rs @@ -141,7 +141,7 @@ impl EvaluationKey { /// ciphertexts. pub fn expands(&self, ct: &Ciphertext, size: usize) -> Result> { let level = size.next_power_of_two().ilog2() as usize; - if ct.c.len() != 2 { + if ct.len() != 2 { Err(Error::DefaultError( "The ciphertext is not of size 2".to_string(), )) @@ -160,8 +160,8 @@ impl EvaluationKey { let sub = gk.relinearize(&out[i])?; if (1 << l) | i < size { out[(1 << l) | i] = &out[i] - ⊂ - out[(1 << l) | i].c[0] *= monomial; - out[(1 << l) | i].c[1] *= monomial; + out[(1 << l) | i][0] *= monomial; + out[(1 << l) | i][1] *= monomial; } out[i] += ⊂ } diff --git a/crates/fhe/src/bfv/keys/galois_key.rs b/crates/fhe/src/bfv/keys/galois_key.rs index 65a9a7b..42b0e8c 100644 --- a/crates/fhe/src/bfv/keys/galois_key.rs +++ b/crates/fhe/src/bfv/keys/galois_key.rs @@ -64,22 +64,22 @@ impl GaloisKey { /// Relinearize a [`Ciphertext`] using the [`GaloisKey`] pub fn relinearize(&self, ct: &Ciphertext) -> Result { // assert_eq!(ct.par, self.ksk.par); - assert_eq!(ct.c.len(), 2); + assert_eq!(ct.len(), 2); - let mut c2 = ct.c[1].substitute(&self.element)?; + let mut c2 = ct[1].substitute(&self.element)?; c2.change_representation(Representation::PowerBasis); let (mut c0, mut c1) = self.ksk.key_switch(&c2)?; - if c0.ctx() != ct.c[0].ctx() { + if c0.ctx() != ct[0].ctx() { c0.change_representation(Representation::PowerBasis); c1.change_representation(Representation::PowerBasis); - c0.mod_switch_down_to(ct.c[0].ctx())?; - c1.mod_switch_down_to(ct.c[1].ctx())?; + c0.mod_switch_down_to(ct[0].ctx())?; + c1.mod_switch_down_to(ct[1].ctx())?; c0.change_representation(Representation::Ntt); c1.change_representation(Representation::Ntt); } - c0 += &ct.c[0].substitute(&self.element)?; + c0 += &ct[0].substitute(&self.element)?; Ok(Ciphertext { par: ct.par.clone(), diff --git a/crates/fhe/src/bfv/keys/public_key.rs b/crates/fhe/src/bfv/keys/public_key.rs index f0d61cf..4f2d8c8 100644 --- a/crates/fhe/src/bfv/keys/public_key.rs +++ b/crates/fhe/src/bfv/keys/public_key.rs @@ -27,7 +27,7 @@ impl PublicKey { let mut c: Ciphertext = sk.try_encrypt(&zero, rng).unwrap(); // The polynomials of a public key should not allow for variable time // computation. - c.c.iter_mut() + c.iter_mut() .for_each(|p| p.disallow_variable_time_computations()); Self { par: sk.par.clone(), @@ -74,10 +74,10 @@ impl FheEncrypter for PublicKey { )?); let m = Zeroizing::new(pt.to_poly()); - let mut c0 = u.as_ref() * &ct.c[0]; + let mut c0 = u.as_ref() * &ct[0]; c0 += &e1; c0 += &m; - let mut c1 = u.as_ref() * &ct.c[1]; + let mut c1 = u.as_ref() * &ct[1]; c1 += &e2; // It is now safe to enable variable time computations. @@ -122,7 +122,7 @@ impl DeserializeParametrized for PublicKey { } else { // The polynomials of a public key should not allow for variable time // computation. - c.c.iter_mut() + c.iter_mut() .for_each(|p| p.disallow_variable_time_computations()); Ok(Self { par: par.clone(), diff --git a/crates/fhe/src/bfv/keys/relinearization_key.rs b/crates/fhe/src/bfv/keys/relinearization_key.rs index b16ece6..7ac26a6 100644 --- a/crates/fhe/src/bfv/keys/relinearization_key.rs +++ b/crates/fhe/src/bfv/keys/relinearization_key.rs @@ -72,7 +72,7 @@ impl RelinearizationKey { /// Relinearize an "extended" ciphertext (c0, c1, c2) into a [`Ciphertext`] pub fn relinearizes(&self, ct: &mut Ciphertext) -> Result<()> { - if ct.c.len() != 3 { + if ct.len() != 3 { Err(Error::DefaultError( "Only supports relinearization of ciphertext with 3 parts".to_string(), )) @@ -81,24 +81,24 @@ impl RelinearizationKey { "Ciphertext has incorrect level".to_string(), )) } else { - let mut c2 = ct.c[2].clone(); + let mut c2 = ct[2].clone(); c2.change_representation(Representation::PowerBasis); #[allow(unused_mut)] let (mut c0, mut c1) = self.relinearizes_poly(&c2)?; - if c0.ctx() != ct.c[0].ctx() { + if c0.ctx() != ct[0].ctx() { c0.change_representation(Representation::PowerBasis); c1.change_representation(Representation::PowerBasis); - c0.mod_switch_down_to(ct.c[0].ctx())?; - c1.mod_switch_down_to(ct.c[1].ctx())?; + c0.mod_switch_down_to(ct[0].ctx())?; + c1.mod_switch_down_to(ct[1].ctx())?; c0.change_representation(Representation::Ntt); c1.change_representation(Representation::Ntt); } - ct.c[0] += &c0; - ct.c[1] += &c1; - ct.c.truncate(2); + ct[0] += &c0; + ct[1] += &c1; + ct.truncate(2); Ok(()) } } @@ -193,7 +193,7 @@ mod tests { // Relinearize the extended ciphertext! rk.relinearizes(&mut ct)?; - assert_eq!(ct.c.len(), 2); + assert_eq!(ct.len(), 2); // Check that the relinearization by polynomials works the same way c2.change_representation(Representation::PowerBasis); @@ -254,7 +254,7 @@ mod tests { // Relinearize the extended ciphertext! rk.relinearizes(&mut ct)?; - assert_eq!(ct.c.len(), 2); + assert_eq!(ct.len(), 2); // Check that the relinearization by polynomials works the same way c2.change_representation(Representation::PowerBasis); diff --git a/crates/fhe/src/bfv/keys/secret_key.rs b/crates/fhe/src/bfv/keys/secret_key.rs index 859aa15..a1ac9b8 100644 --- a/crates/fhe/src/bfv/keys/secret_key.rs +++ b/crates/fhe/src/bfv/keys/secret_key.rs @@ -13,23 +13,17 @@ use num_bigint::BigUint; use rand::{thread_rng, CryptoRng, Rng, RngCore, SeedableRng}; use rand_chacha::ChaCha8Rng; use std::sync::Arc; -use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing}; +use zeroize::Zeroizing; +use zeroize_derive::{Zeroize, ZeroizeOnDrop}; /// Secret key for the BFV encryption scheme. -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone, Zeroize, ZeroizeOnDrop)] pub struct SecretKey { + #[zeroize(skip)] pub(crate) par: Arc, pub(crate) coeffs: Box<[i64]>, } -impl Zeroize for SecretKey { - fn zeroize(&mut self) { - self.coeffs.zeroize(); - } -} - -impl ZeroizeOnDrop for SecretKey {} - impl SecretKey { /// Generate a random [`SecretKey`]. pub fn random(par: &Arc, rng: &mut R) -> Self { @@ -40,7 +34,7 @@ impl SecretKey { /// Generate a [`SecretKey`] from its coefficients. pub(crate) fn new(coeffs: Vec, par: &Arc) -> Self { Self { - par: par.clone(), + par: par.to_owned(), coeffs: coeffs.into_boxed_slice(), } } @@ -58,7 +52,7 @@ impl SecretKey { // Let's create a secret key with the ciphertext context let mut s = Zeroizing::new(Poly::try_convert_from( self.coeffs.as_ref(), - ct.c[0].ctx(), + ct[0].ctx(), false, Representation::PowerBasis, )?); @@ -66,11 +60,11 @@ impl SecretKey { let mut si = s.clone(); // Let's disable variable time computations - let mut c = Zeroizing::new(ct.c[0].clone()); + let mut c = Zeroizing::new(ct[0].clone()); c.disallow_variable_time_computations(); - for i in 1..ct.c.len() { - let mut cis = Zeroizing::new(ct.c[i].clone()); + for i in 1..ct.len() { + let mut cis = Zeroizing::new(ct[i].clone()); cis.disallow_variable_time_computations(); *cis.as_mut() *= si.as_ref(); *c.as_mut() += &cis; @@ -79,7 +73,7 @@ impl SecretKey { *c.as_mut() -= &m; c.change_representation(Representation::PowerBasis); - let ciphertext_modulus = ct.c[0].ctx().modulus(); + let ciphertext_modulus = ct[0].ctx().modulus(); let mut noise = 0usize; for coeff in Vec::::from(c.as_ref()) { noise = std::cmp::max( @@ -165,24 +159,24 @@ impl FheDecrypter for SecretKey { // Let's create a secret key with the ciphertext context let mut s = Zeroizing::new(Poly::try_convert_from( self.coeffs.as_ref(), - ct.c[0].ctx(), + ct[0].ctx(), false, Representation::PowerBasis, )?); s.change_representation(Representation::Ntt); let mut si = s.clone(); - let mut c = Zeroizing::new(ct.c[0].clone()); + let mut c = Zeroizing::new(ct[0].clone()); c.disallow_variable_time_computations(); // Compute the phase c0 + c1*s + c2*s^2 + ... where the secret power // s^k is computed on-the-fly - for i in 1..ct.c.len() { - let mut cis = Zeroizing::new(ct.c[i].clone()); + for i in 1..ct.len() { + let mut cis = Zeroizing::new(ct[i].clone()); cis.disallow_variable_time_computations(); *cis.as_mut() *= si.as_ref(); *c.as_mut() += &cis; - if i + 1 < ct.c.len() { + if i + 1 < ct.len() { *si.as_mut() *= s.as_ref(); } } @@ -194,7 +188,7 @@ impl FheDecrypter for SecretKey { let v = Zeroizing::new( Vec::::from(d.as_ref()) .iter_mut() - .map(|vi| *vi + self.par.plaintext.modulus()) + .map(|vi| *vi + *self.par.plaintext) .collect_vec(), ); let mut w = v[..self.par.degree()].to_vec(); @@ -203,7 +197,7 @@ impl FheDecrypter for SecretKey { self.par.plaintext.reduce_vec(&mut w); let mut poly = - Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?; + Poly::try_convert_from(&w, ct[0].ctx(), false, Representation::PowerBasis)?; poly.change_representation(Representation::Ntt); let pt = Plaintext { diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index b80e7fb..746e163 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -63,14 +63,14 @@ where )); } let ct_first = ct.clone().next().unwrap(); - let ctx = ct_first.c[0].ctx(); + let ctx = ct_first[0].ctx(); if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| { - cti.par != ct_first.par || pti.par != ct_first.par || cti.c.len() != ct_first.c.len() + cti.par != ct_first.par || pti.par != ct_first.par || cti.len() != ct_first.len() }) { return Err(Error::DefaultError("Mismatched parameters".to_string())); } - if ct.clone().any(|cti| cti.c.len() != ct_first.c.len()) { + if ct.clone().any(|cti| cti.len() != ct_first.len()) { return Err(Error::DefaultError( "Mismatched number of parts in the ciphertexts".to_string(), )); @@ -86,10 +86,10 @@ where if count as u128 > *min_of_max { // Too many ciphertexts for the optimized method, instead, we call // `poly_dot_product`. - let c = (0..ct_first.c.len()) + let c = (0..ct_first.len()) .map(|i| { poly_dot_product( - ct.clone().map(|cti| unsafe { cti.c.get_unchecked(i) }), + ct.clone().map(|cti| unsafe { cti.get_unchecked(i) }), pt.clone().map(|pti| &pti.poly_ntt), ) .map_err(Error::MathError) @@ -103,10 +103,10 @@ where level: ct_first.level, }) } else { - let mut acc = Array::zeros((ct_first.c.len(), ctx.moduli().len(), ct_first.par.degree())); + let mut acc = Array::zeros((ct_first.len(), ctx.moduli().len(), ct_first.par.degree())); for (ciphertext, plaintext) in izip!(ct, pt) { let pt_coefficients = plaintext.poly_ntt.coefficients(); - for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.c.iter()) { + for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.iter()) { let ci_coefficients = ci.coefficients(); for (mut accij, cij, pij) in izip!( acci.outer_iter_mut(), @@ -125,7 +125,7 @@ where } // Reduce - let mut c = Vec::with_capacity(ct_first.c.len()); + let mut c = Vec::with_capacity(ct_first.len()); for acci in acc.outer_iter() { let mut coeffs = Array2::zeros((ctx.moduli().len(), ct_first.par.degree())); for (mut outij, accij, q) in izip!( diff --git a/crates/fhe/src/bfv/ops/mod.rs b/crates/fhe/src/bfv/ops/mod.rs index 2e0bce0..a6b6cc5 100644 --- a/crates/fhe/src/bfv/ops/mod.rs +++ b/crates/fhe/src/bfv/ops/mod.rs @@ -9,7 +9,7 @@ pub use mul::Multiplicator; use super::{Ciphertext, Plaintext}; use crate::{Error, Result}; use fhe_math::rq::{Poly, Representation}; -use itertools::{izip, Itertools}; +use itertools::{izip, Itertools as _}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; impl Add<&Ciphertext> for &Ciphertext { @@ -26,12 +26,12 @@ impl AddAssign<&Ciphertext> for Ciphertext { fn add_assign(&mut self, rhs: &Ciphertext) { assert_eq!(self.par, rhs.par); - if self.c.is_empty() { + if self.is_empty() { *self = rhs.clone() - } else if !rhs.c.is_empty() { + } else if !rhs.is_empty() { assert_eq!(self.level, rhs.level); - assert_eq!(self.c.len(), rhs.c.len()); - izip!(&mut self.c, &rhs.c).for_each(|(c1i, c2i)| *c1i += c2i); + assert_eq!(self.len(), rhs.len()); + izip!(self.iter_mut(), rhs.iter()).for_each(|(c1i, c2i)| *c1i += c2i); self.seed = None } } @@ -58,11 +58,11 @@ impl Add<&Ciphertext> for &Plaintext { impl AddAssign<&Plaintext> for Ciphertext { fn add_assign(&mut self, rhs: &Plaintext) { assert_eq!(self.par, rhs.par); - assert!(!self.c.is_empty()); + assert!(!self.is_empty()); assert_eq!(self.level, rhs.level); let poly = rhs.to_poly(); - self.c[0] += &poly; + self[0] += &poly; self.seed = None } } @@ -81,12 +81,12 @@ impl SubAssign<&Ciphertext> for Ciphertext { fn sub_assign(&mut self, rhs: &Ciphertext) { assert_eq!(self.par, rhs.par); - if self.c.is_empty() { + if self.is_empty() { *self = -rhs - } else if !rhs.c.is_empty() { + } else if !rhs.is_empty() { assert_eq!(self.level, rhs.level); - assert_eq!(self.c.len(), rhs.c.len()); - izip!(&mut self.c, &rhs.c).for_each(|(c1i, c2i)| *c1i -= c2i); + assert_eq!(self.len(), rhs.len()); + izip!(self.iter_mut(), rhs.iter()).for_each(|(c1i, c2i)| *c1i -= c2i); self.seed = None } } @@ -113,7 +113,7 @@ impl Sub<&Ciphertext> for &Plaintext { impl SubAssign<&Plaintext> for Ciphertext { fn sub_assign(&mut self, rhs: &Plaintext) { assert_eq!(self.par, rhs.par); - assert!(!self.c.is_empty()); + assert!(!self.is_empty()); assert_eq!(self.level, rhs.level); let poly = rhs.to_poly(); @@ -126,7 +126,7 @@ impl Neg for &Ciphertext { type Output = Ciphertext; fn neg(self) -> Ciphertext { - let c = self.c.iter().map(|c1i| -c1i).collect_vec(); + let c = self.iter().map(|c1i| -c1i).collect_vec(); Ciphertext { par: self.par.clone(), seed: None, @@ -140,7 +140,7 @@ impl Neg for Ciphertext { type Output = Ciphertext; fn neg(mut self) -> Ciphertext { - self.c.iter_mut().for_each(|c1i| *c1i = -&*c1i); + self.iter_mut().for_each(|c1i| *c1i = -&*c1i); self.seed = None; self } @@ -149,9 +149,9 @@ impl Neg for Ciphertext { impl MulAssign<&Plaintext> for Ciphertext { fn mul_assign(&mut self, rhs: &Plaintext) { assert_eq!(self.par, rhs.par); - if !self.c.is_empty() { + if !self.is_empty() { assert_eq!(self.level, rhs.level); - self.c.iter_mut().for_each(|ci| *ci *= &rhs.poly_ntt); + self.iter_mut().for_each(|ci| *ci *= &rhs.poly_ntt); } self.seed = None } @@ -171,7 +171,7 @@ impl Mul<&Ciphertext> for &Ciphertext { type Output = Ciphertext; fn mul(self, rhs: &Ciphertext) -> Ciphertext { - if self.c.is_empty() { + if self.is_empty() { return self.clone(); } @@ -182,7 +182,6 @@ impl Mul<&Ciphertext> for &Ciphertext { // Scale all ciphertexts // let mut now = std::time::SystemTime::now(); let self_c = self - .c .iter() .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) .collect::>>() @@ -228,13 +227,11 @@ impl Mul<&Ciphertext> for &Ciphertext { // Scale all ciphertexts // let mut now = std::time::SystemTime::now(); let self_c = self - .c .iter() .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) .collect::>>() .unwrap(); let other_c = rhs - .c .iter() .map(|ci| ci.scale(&mp.extender).map_err(Error::MathError)) .collect::>>() @@ -536,7 +533,7 @@ mod tests { } } EncodingEnum::Simd => { - c = a.clone(); + c.clone_from(&a); params.plaintext.mul_vec(&mut c, &b); } } diff --git a/crates/fhe/src/bfv/ops/mul.rs b/crates/fhe/src/bfv/ops/mul.rs index b753251..28184fb 100644 --- a/crates/fhe/src/bfv/ops/mul.rs +++ b/crates/fhe/src/bfv/ops/mul.rs @@ -121,10 +121,7 @@ impl Multiplicator { ScalingFactor::one(), ScalingFactor::one(), &extended_basis, - ScalingFactor::new( - &BigUint::from(rk.ksk.par.plaintext.modulus()), - ctx.modulus(), - ), + ScalingFactor::new(&BigUint::from(*rk.ksk.par.plaintext), ctx.modulus()), rk.ksk.ciphertext_level, &rk.ksk.par, )?; @@ -170,17 +167,17 @@ impl Multiplicator { "Ciphertexts are not at expected level".to_string(), )); } - if lhs.c.len() != 2 || rhs.c.len() != 2 { + if lhs.len() != 2 || rhs.len() != 2 { return Err(Error::DefaultError( "Multiplication can only be performed on ciphertexts of size 2".to_string(), )); } // Extend - let c00 = lhs.c[0].scale(&self.extender_lhs)?; - let c01 = lhs.c[1].scale(&self.extender_lhs)?; - let c10 = rhs.c[0].scale(&self.extender_rhs)?; - let c11 = rhs.c[1].scale(&self.extender_rhs)?; + let c00 = lhs[0].scale(&self.extender_lhs)?; + let c01 = lhs[1].scale(&self.extender_lhs)?; + let c10 = rhs[0].scale(&self.extender_rhs)?; + let c11 = rhs[1].scale(&self.extender_rhs)?; // Multiply let mut c0 = &c00 * &c10; @@ -230,7 +227,7 @@ impl Multiplicator { if self.mod_switch { c.mod_switch_to_next_level()?; } else { - c.c.iter_mut() + c.iter_mut() .for_each(|p| p.change_representation(Representation::Ntt)); } diff --git a/crates/fhe/src/bfv/parameters.rs b/crates/fhe/src/bfv/parameters.rs index 600e926..34b24cf 100644 --- a/crates/fhe/src/bfv/parameters.rs +++ b/crates/fhe/src/bfv/parameters.rs @@ -158,27 +158,6 @@ impl BfvParameters { 0x1ffffffe48001, ], ); - n_and_qs.insert( - 32768, - vec![ - 0x7fffffffe90001, - 0x7fffffffbf0001, - 0x7fffffffbd0001, - 0x7fffffffba0001, - 0x7fffffffaa0001, - 0x7fffffffa50001, - 0x7fffffff9f0001, - 0x7fffffff7e0001, - 0x7fffffff770001, - 0x7fffffff380001, - 0x7fffffff330001, - 0x7fffffff2d0001, - 0x7fffffff170001, - 0x7fffffff150001, - 0x7ffffffef00001, - 0xfffffffff70001, - ], - ); let mut params = vec![]; @@ -258,7 +237,7 @@ impl BfvParametersBuilder { /// Only one of `set_moduli_sizes` and `set_moduli` /// can be specified. pub fn set_moduli_sizes(&mut self, sizes: &[usize]) -> &mut Self { - self.ciphertext_moduli_sizes = sizes.to_owned(); + sizes.clone_into(&mut self.ciphertext_moduli_sizes); self } @@ -266,7 +245,7 @@ impl BfvParametersBuilder { /// Only one of `set_moduli_sizes` and `set_moduli` /// can be specified. pub fn set_moduli(&mut self, moduli: &[u64]) -> &mut Self { - self.ciphertext_moduli = moduli.to_owned(); + moduli.clone_into(&mut self.ciphertext_moduli); self } @@ -370,7 +349,7 @@ impl BfvParametersBuilder { let mut delta_rests = vec![]; for m in &moduli { let q = Modulus::new(*m)?; - delta_rests.push(q.inv(q.neg(plaintext_modulus.modulus())).unwrap()) + delta_rests.push(q.inv(q.neg(*plaintext_modulus)).unwrap()) } let mut ctx = Vec::with_capacity(moduli.len()); @@ -390,16 +369,12 @@ impl BfvParametersBuilder { p.change_representation(Representation::NttShoup); delta.push(p); - q_mod_t.push( - (rns.modulus() % plaintext_modulus.modulus()) - .to_u64() - .unwrap(), - ); + q_mod_t.push((rns.modulus() % *plaintext_modulus).to_u64().unwrap()); scalers.push(Scaler::new( &ctx_i, &plaintext_ctx, - ScalingFactor::new(&BigUint::from(plaintext_modulus.modulus()), rns.modulus()), + ScalingFactor::new(&BigUint::from(*plaintext_modulus), rns.modulus()), )?); // For the first multiplication, we want to extend to a context that @@ -414,7 +389,7 @@ impl BfvParametersBuilder { &ctx_i, &mul_1_ctx, ScalingFactor::one(), - ScalingFactor::new(&BigUint::from(plaintext_modulus.modulus()), ctx_i.modulus()), + ScalingFactor::new(&BigUint::from(*plaintext_modulus), ctx_i.modulus()), )?); ctx.push(ctx_i); @@ -440,17 +415,17 @@ impl BfvParametersBuilder { Ok(BfvParameters { polynomial_degree: self.degree, plaintext_modulus: self.plaintext, - moduli: moduli.into_boxed_slice(), - moduli_sizes: moduli_sizes.into_boxed_slice(), + moduli: moduli.into(), + moduli_sizes: moduli_sizes.into(), variance: self.variance, ctx, op: op.map(Arc::new), - delta: delta.into_boxed_slice(), - q_mod_t: q_mod_t.into_boxed_slice(), - scalers: scalers.into_boxed_slice(), + delta: delta.into(), + q_mod_t: q_mod_t.into(), + scalers: scalers.into(), plaintext: plaintext_modulus, - mul_params: mul_params.into_boxed_slice(), - matrix_reps_index_map: matrix_reps_index_map.into_boxed_slice(), + mul_params: mul_params.into(), + matrix_reps_index_map: matrix_reps_index_map.into(), }) } } diff --git a/crates/fhe/src/bfv/plaintext.rs b/crates/fhe/src/bfv/plaintext.rs index ef37d7d..52f8cc5 100644 --- a/crates/fhe/src/bfv/plaintext.rs +++ b/crates/fhe/src/bfv/plaintext.rs @@ -6,22 +6,26 @@ use crate::{ use fhe_math::rq::{traits::TryConvertFrom, Context, Poly, Representation}; use fhe_traits::{FheDecoder, FheEncoder, FheParametrized, FhePlaintext}; use std::sync::Arc; -use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing}; +use zeroize::{Zeroize as _, Zeroizing}; +use zeroize_derive::{Zeroize, ZeroizeOnDrop}; use super::encoding::EncodingEnum; /// A plaintext object, that encodes a vector according to a specific encoding. -#[derive(Debug, Clone, Eq)] +#[derive(Debug, Clone, Eq, Zeroize, ZeroizeOnDrop)] pub struct Plaintext { /// The parameters of the underlying BFV encryption scheme. + #[zeroize(skip)] pub(crate) par: Arc, /// The value after encoding. pub(crate) value: Box<[u64]>, /// The encoding of the plaintext, if known + #[zeroize(skip)] pub(crate) encoding: Option, /// The plaintext as a polynomial. pub(crate) poly_ntt: Poly, /// The level of the plaintext + #[zeroize(skip)] pub(crate) level: usize, } @@ -33,16 +37,6 @@ impl FhePlaintext for Plaintext { type Encoding = Encoding; } -// Zeroizing of plaintexts. -impl ZeroizeOnDrop for Plaintext {} - -impl Zeroize for Plaintext { - fn zeroize(&mut self) { - self.value.zeroize(); - self.poly_ntt.zeroize(); - } -} - impl Plaintext { pub(crate) fn to_poly(&self) -> Poly { let mut m_v = Zeroizing::new(self.value.clone()); @@ -153,7 +147,7 @@ impl<'a> FheEncoder<&'a [u64]> for Plaintext { return Err(Error::TooManyValues(value.len(), par.degree())); } let v = PlaintextVec::try_encode(value, encoding, par)?; - Ok(v.0[0].clone()) + Ok(v[0].clone()) } } diff --git a/crates/fhe/src/bfv/plaintext_vec.rs b/crates/fhe/src/bfv/plaintext_vec.rs index 8fdecc1..74d53fa 100644 --- a/crates/fhe/src/bfv/plaintext_vec.rs +++ b/crates/fhe/src/bfv/plaintext_vec.rs @@ -1,8 +1,8 @@ -use std::{cmp::min, sync::Arc}; +use std::{cmp::min, ops::Deref, sync::Arc}; use fhe_math::rq::{traits::TryConvertFrom, Poly, Representation}; use fhe_traits::{FheEncoder, FheEncoderVariableTime, FheParametrized, FhePlaintext}; -use zeroize::{Zeroize, ZeroizeOnDrop}; +use zeroize_derive::{Zeroize, ZeroizeOnDrop}; use crate::{ bfv::{BfvParameters, Encoding, Plaintext}, @@ -13,7 +13,16 @@ use super::encoding::EncodingEnum; /// A wrapper around a vector of plaintext which implements the [`FhePlaintext`] /// trait, and therefore can be encoded to / decoded from. -pub struct PlaintextVec(pub Vec); +#[derive(Zeroize, ZeroizeOnDrop)] +pub struct PlaintextVec(Vec<Plaintext>); + +impl Deref for PlaintextVec { + type Target = [Plaintext]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} impl FhePlaintext for PlaintextVec { type Encoding = Encoding; @@ -23,14 +32,6 @@ impl FheParametrized for PlaintextVec { type Parameters = BfvParameters; } -impl Zeroize for PlaintextVec { - fn zeroize(&mut self) { - self.0.zeroize() - } -} - -impl ZeroizeOnDrop for PlaintextVec {} - impl FheEncoderVariableTime<&[u64]> for PlaintextVec { type Error = Error; @@ -72,7 +73,7 @@ impl FheEncoderVariableTime<&[u64]> for PlaintextVec { Ok(Plaintext { par: par.clone(), - value: v.into_boxed_slice(), + value: v.into(), encoding: Some(encoding.clone()), poly_ntt: poly, level: encoding.level, @@ -119,7 +120,7 @@ impl FheEncoder<&[u64]> for PlaintextVec { Ok(Plaintext { par: par.clone(), - value: v.into_boxed_slice(), + value: v.into(), encoding: Some(encoding.clone()), poly_ntt: poly, level: encoding.level, diff --git a/crates/fhe/src/bfv/rgsw_ciphertext.rs b/crates/fhe/src/bfv/rgsw_ciphertext.rs index 626fb9b..b67d71a 100644 --- a/crates/fhe/src/bfv/rgsw_ciphertext.rs +++ b/crates/fhe/src/bfv/rgsw_ciphertext.rs @@ -119,10 +119,10 @@ impl Mul<&RGSWCiphertext> for &Ciphertext { self.level, rhs.ksk0.ciphertext_level, "Ciphertext and RGSWCiphertext must have the same level" ); - assert_eq!(self.c.len(), 2, "Ciphertext must have two parts"); + assert_eq!(self.len(), 2, "Ciphertext must have two parts"); - let mut ct0 = self.c[0].clone(); - let mut ct1 = self.c[1].clone(); + let mut ct0 = self[0].clone(); + let mut ct1 = self[1].clone(); ct0.change_representation(Representation::PowerBasis); ct1.change_representation(Representation::PowerBasis); diff --git a/crates/fhe/src/mbfv/public_key_gen.rs b/crates/fhe/src/mbfv/public_key_gen.rs index cb77839..618338e 100644 --- a/crates/fhe/src/mbfv/public_key_gen.rs +++ b/crates/fhe/src/mbfv/public_key_gen.rs @@ -82,12 +82,15 @@ impl Aggregate<PublicKeyShare> for PublicKey { #[cfg(test)] mod tests { - use super::*; - use fhe_traits::{FheEncoder, FheEncrypter}; use rand::thread_rng; - use crate::bfv::{BfvParameters, Encoding, Plaintext, SecretKey}; + use crate::{ + bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey}, + mbfv::{Aggregate as _, CommonRandomPoly}, + }; + + use super::PublicKeyShare; const NUM_PARTIES: usize = 11; diff --git a/crates/fhe/src/mbfv/public_key_switch.rs b/crates/fhe/src/mbfv/public_key_switch.rs index e4fd84f..1da2d89 100644 --- a/crates/fhe/src/mbfv/public_key_switch.rs +++ b/crates/fhe/src/mbfv/public_key_switch.rs @@ -64,14 +64,14 @@ impl PublicKeySwitchShare { let e0 = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); let e1 = Zeroizing::new(Poly::small(ctx, Representation::Ntt, par.variance, rng)?); - let mut h0 = pk_ct.c[0].clone(); + let mut h0 = pk_ct[0].clone(); h0.disallow_variable_time_computations(); h0 *= u.as_ref(); - *s.as_mut() *= &ct.c[1]; + *s.as_mut() *= &ct[1]; h0 += s.as_ref(); h0 += e0.as_ref(); - let mut h1 = pk_ct.c[1].clone(); + let mut h1 = pk_ct[1].clone(); h1.disallow_variable_time_computations(); h1 *= u.as_ref(); h1 += e1.as_ref(); @@ -83,7 +83,7 @@ impl PublicKeySwitchShare { Ok(Self { par, - c0: ct.c[0].clone(), + c0: ct[0].clone(), h0_share: h0, h1_share: h1, }) @@ -118,12 +118,10 @@ mod tests { use rand::thread_rng; use crate::{ - bfv::{BfvParameters, Encoding, Plaintext, SecretKey}, - mbfv::{AggregateIter, CommonRandomPoly, PublicKeyShare}, + bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey}, + mbfv::{AggregateIter, CommonRandomPoly, PublicKeyShare, PublicKeySwitchShare}, }; - use super::*; - const NUM_PARTIES: usize = 11; struct Party { diff --git a/crates/fhe/src/mbfv/relin_key_gen.rs b/crates/fhe/src/mbfv/relin_key_gen.rs index 6adf2d3..004a2ce 100644 --- a/crates/fhe/src/mbfv/relin_key_gen.rs +++ b/crates/fhe/src/mbfv/relin_key_gen.rs @@ -368,13 +368,17 @@ impl Aggregate<RelinKeyShare<R2>> for RelinearizationKey { #[cfg(test)] mod tests { - use super::*; + use std::sync::Arc; + use fhe_traits::{FheDecoder, FheEncoder, FheEncrypter}; use rand::thread_rng; use crate::{ - bfv::{BfvParameters, Encoding, Multiplicator, Plaintext, PublicKey}, - mbfv::{AggregateIter, DecryptionShare, PublicKeyShare}, + bfv::{BfvParameters, Encoding, Multiplicator, Plaintext, PublicKey, SecretKey}, + mbfv::{ + Aggregate as _, AggregateIter, CommonRandomPoly, DecryptionShare, PublicKeyShare, + RelinKeyGenerator, + }, }; const NUM_PARTIES: usize = 5; @@ -443,7 +447,7 @@ mod tests { multiplicator.enable_mod_switching().unwrap(); } let ct = Arc::new(multiplicator.multiply(&ct1, &ct2).unwrap()); - assert_eq!(ct.c.len(), 2); + assert_eq!(ct.len(), 2); // Parties perform a collective decryption let pt = party_sks diff --git a/crates/fhe/src/mbfv/secret_key_switch.rs b/crates/fhe/src/mbfv/secret_key_switch.rs index 68381bd..68cb864 100644 --- a/crates/fhe/src/mbfv/secret_key_switch.rs +++ b/crates/fhe/src/mbfv/secret_key_switch.rs @@ -49,21 +49,21 @@ impl SecretKeySwitchShare { )); } // Note: M-BFV implementation only supports ciphertext of length 2 - if ct.c.len() != 2 { - return Err(Error::TooManyValues(ct.c.len(), 2)); + if ct.len() != 2 { + return Err(Error::TooManyValues(ct.len(), 2)); } let par = sk_input_share.par.clone(); let mut s_in = Zeroizing::new(Poly::try_convert_from( sk_input_share.coeffs.as_ref(), - ct.c[0].ctx(), + ct[0].ctx(), false, Representation::PowerBasis, )?); s_in.change_representation(Representation::Ntt); let mut s_out = Zeroizing::new(Poly::try_convert_from( sk_output_share.coeffs.as_ref(), - ct.c[0].ctx(), + ct[0].ctx(), false, Representation::PowerBasis, )?); @@ -72,7 +72,7 @@ impl SecretKeySwitchShare { // Sample error // TODO this should be exponential in ciphertext noise! let e = Zeroizing::new(Poly::small( - ct.c[0].ctx(), + ct[0].ctx(), Representation::Ntt, par.variance, rng, @@ -81,7 +81,7 @@ impl SecretKeySwitchShare { // Create h_i share let mut h_share = s_in.as_ref() - s_out.as_ref(); h_share.disallow_variable_time_computations(); - h_share *= &ct.c[1]; + h_share *= &ct[1]; h_share += e.as_ref(); Ok(Self { par, ct, h_share }) @@ -100,8 +100,8 @@ impl Aggregate<SecretKeySwitchShare> for Ciphertext { h += &sh.h_share; } - let c0 = &share.ct.c[0] + &h; - let c1 = share.ct.c[1].clone(); + let c0 = &share.ct[0] + &h; + let c1 = share.ct[1].clone(); Ciphertext::new(vec![c0, c1], &share.par) } @@ -142,32 +142,30 @@ impl Aggregate<DecryptionShare> for Plaintext { { let sks_shares = iter.into_iter().map(|s| s.sks_share); let ct = Ciphertext::from_shares(sks_shares)?; - let par = ct.par; // Note: during SKS, c[1]*sk has already been added to c[0]. - let mut c = Zeroizing::new(ct.c[0].clone()); + let mut c = Zeroizing::new(ct[0].clone()); c.disallow_variable_time_computations(); c.change_representation(Representation::PowerBasis); // The true decryption part is done during SKS; all that is left is to scale - let d = Zeroizing::new(c.scale(&par.scalers[ct.level])?); + let d = Zeroizing::new(c.scale(&ct.par.scalers[ct.level])?); let v = Zeroizing::new( Vec::<u64>::from(d.as_ref()) .iter_mut() - .map(|vi| *vi + par.plaintext.modulus()) + .map(|vi| *vi + *ct.par.plaintext) .collect_vec(), ); - let mut w = v[..par.degree()].to_vec(); - let q = Modulus::new(par.moduli[0]).map_err(Error::MathError)?; + let mut w = v[..ct.par.degree()].to_vec(); + let q = Modulus::new(ct.par.moduli[0]).map_err(Error::MathError)?; q.reduce_vec(&mut w); - par.plaintext.reduce_vec(&mut w); + ct.par.plaintext.reduce_vec(&mut w); - let mut poly = - Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?; + let mut poly = Poly::try_convert_from(&w, ct[0].ctx(), false, Representation::PowerBasis)?; poly.change_representation(Representation::Ntt); let pt = Plaintext { - par: par.clone(), + par: ct.par.clone(), value: w.into_boxed_slice(), encoding: None, poly_ntt: poly, @@ -187,11 +185,12 @@ mod tests { use crate::{ bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey}, - mbfv::{Aggregate, AggregateIter, CommonRandomPoly, PublicKeyShare}, + mbfv::{ + Aggregate, AggregateIter, CommonRandomPoly, DecryptionShare, PublicKeyShare, + SecretKeySwitchShare, + }, }; - use super::*; - const NUM_PARTIES: usize = 11; struct Party {