diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..1b34ec4 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,11 @@ +name: Test +on: [pull_request] + +jobs: + test: + name: cargo test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - run: cargo test --release --features "tfhe/x86_64-unix" -- --test-threads=1 diff --git a/Cargo.toml b/Cargo.toml index a61c9b5..930e05c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tfhe = { version = "0.4.1", features = [ "boolean", "shortint", "integer", "aarch64-unix" ] } +tfhe = { version = "0.4.1", features = [ "boolean", "shortint", "integer" ] } serde = { version = "1.0", features = ["derive"] } rayon = "1.8" env_logger = "0.10.0" diff --git a/README.md b/README.md index b7e95b1..a4cebd1 100644 --- a/README.md +++ b/README.md @@ -70,38 +70,9 @@ cargo doc --no-deps --open cargo test --doc --release -- --show-output ``` -## State of this project - +## Acknowledgements This project has been developed for the [Zama Bounty Program](https://github.com/zama-ai/bounty-program), specifically for the bounty ["Create a string library that works on encrypted data using TFHE-rs"](https://github.com/zama-ai/bounty-program/issues/80). -### Deviations from bounty description - -We chose to develop this library under the principle **"everything encrypted first"**. This means that support for operations with encrypted inputs has been prioritized over support for operations where parts of the input (e.g., the pattern) is not encrypted, or where the strings are encrypted in a way that leaks their length. -In the following we list some aspects in which our implementation deviates from the bounty description. 
- -#### Cleartext input - -- *No optimizations for unpadded strings:* The original bounty description stated that all strings should be 0-padded. Later, this requirement was relaxed (see note in [bounty description](https://github.com/zama-ai/bounty-program/issues/80)) to allow for unpadded strings that are indentifiable as such without decryption. Due to time constraints, unpadded strings, or any optimizations in that regard, are currently not implemented. However, we do list potential optimizations further below. - -- *No optimizations for partial cleartext input:* We did not implement a dedicated cleartext API or any optimizations for it. We support these operations by first encrypting the cleartext inputs and then calling the corresponding ciphertext API. - -#### Project structure - -- *String functions implemented on `FheString` instead of `ServerKey`:* The bounty description asks for the string functions to be implemented on the server key type. However, we found it to be more intuitive to have the functions on the `FheString` type, similar to how regular string functions are available on their string type. (Obviously, this can easily be changed on request.) - -- *Standalone library instead of `tfhe-rs` example:* The bounty description asks for the code be provided as an example that is part of the `tfhe-rs` codebase. However, we found that compilation times were much longer when compiling the code in form of an example compared compiling it as a standalone library. As this was limiting code iteration time, we decided to develop and provide the code in form of a standalone library. (Obviously, this can easily be changed on request.) - -#### String length - -- *Restricted to strings of length < 256:* Currently, the library does not support encrypted strings longer than 255 characters. This is due to the fact that for our `FheString` algorithms to work, we need to be able to represent encrypted integers up to the maximum string length. 
The size of encrypted integers is fixed at key generation. We could have opted for supporting longer strings (in fact, this is an easy change to the key generation function), but we felt that 256 characters is more than enough initially, considering the limited performance. - -### Potential optimizations - -Currently, all encrypted strings are 0-padded. -In the following, we outline a number of potential optimizations that could be applied if support for unpadded encrypted strings is added in the future. +## License -- `ends_with`: currently need to go through whole string because we don't know - length. then only need to compare the respective ends of the two encrypted - strings. -- `add`, `repeat`: currently this is a quadratic operation because we don't know - where the boundaries are. if we don't have padding, we can just append. +TBD \ No newline at end of file diff --git a/examples/cmd/main.rs b/examples/cmd/main.rs index 96e8815..458aeea 100644 --- a/examples/cmd/main.rs +++ b/examples/cmd/main.rs @@ -504,6 +504,7 @@ fn main() { }, ]; + let start = Instant::now(); test_cases.iter().for_each(|t| { let start = Instant::now(); let result_std = (t.std)(&args); @@ -525,6 +526,8 @@ fn main() { } ) }); + let duration = start.elapsed(); + println!("\nDuration (total): {:?}", duration); } trait TestCaseOutput: Debug { diff --git a/src/ciphertext/compare.rs b/src/ciphertext/compare.rs index b4cda3d..db1840f 100644 --- a/src/ciphertext/compare.rs +++ b/src/ciphertext/compare.rs @@ -3,13 +3,13 @@ use std::cmp; use rayon::{join, prelude::*}; -use tfhe::integer::RadixCiphertext; +use tfhe::integer::{IntegerCiphertext, RadixCiphertext}; use crate::server_key::ServerKey; use super::{ logic::{binary_and, binary_and_vec, binary_not, binary_or}, - FheAsciiChar, FheString, + FheString, }; impl FheString { @@ -23,19 +23,32 @@ impl FheString { /// Returns `self == s`. The result is an encryption of 1 if this is the /// case and an encryption of 0 otherwise. 
pub fn eq(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { - // Pad to same length. - let l = cmp::max(self.max_len(), s.max_len()); - let a = self.pad(k, l); - let b = s.pad(k, l); + // Compare overlapping part. + let l = cmp::min(self.max_len(), s.max_len()); + let a = self.substr_clear(k, 0, l); + let b = s.substr_clear(k, 0, l); - // is_eq[i] = a[i] == b[i] - let is_eq = - a.0.par_iter() - .zip(b.0) - .map(|(ai, bi)| k.k.eq_parallelized(&ai.0, &bi.0)) - .collect::<Vec<_>>(); + let (overlap_eq, overhang_empty) = join( + || { + // Convert strings to radix integers and rely on optimized comparison. + let radix_a = a.to_long_radix(); + let radix_b = b.to_long_radix(); + let eq = k.k.eq_parallelized(&radix_a, &radix_b); - binary_and_vec(k, &is_eq) + // Trim exceeding radix blocks to ensure compatibility. + k.k.trim_radix_blocks_msb(&eq, eq.blocks().len() - k.num_blocks) + }, + || { + // Ensure that overhang is empty. + match self.max_len().cmp(&s.max_len()) { + cmp::Ordering::Greater => self.substr_clear(k, l, self.max_len()).is_empty(k), + cmp::Ordering::Less => s.substr_clear(k, l, s.max_len()).is_empty(k), + cmp::Ordering::Equal => k.create_one(), + } + }, + ); + + binary_and(k, &overlap_eq, &overhang_empty) } /// Returns `self != s`. The result is an encryption of 1 if this is the @@ -86,7 +99,7 @@ impl FheString { is_lt = binary_or(k, &is_lt, &ai_lt_bi_and_eq); // is_eq = is_eq && ai == bi - is_eq = k.k.mul_parallelized(&is_eq, ai_eq_bi); + is_eq = binary_and(k, &is_eq, ai_eq_bi); }); is_lt } @@ -125,26 +138,57 @@ impl FheString { binary_and_vec(k, &v) } - /// Returns whether `self[i..i+s.len]` and `s` are equal. The result is an - /// encryption of 1 if this is the case and an encryption of 0 otherwise. + /// Returns whether `self[i..i+s.len]` and `s` are equal. pub fn substr_eq(&self, k: &ServerKey, i: usize, s: &FheString) -> RadixCiphertext { // Extract substring. 
+ let a = self.substr_clear(k, i, self.max_len()); let b = s; - let b_len = b.len(k); - let a = self.substr_clear(k, i); - let a = a.truncate(k, &b_len); - a.eq(k, b) + + let (mut v, overhang_empty) = join( + || { + // v[i] = a[i] == b[i] || b[i] == TERMINATOR + a.0.par_iter() + .zip(&b.0) + .map(|(ai, bi)| { + let eq = k.k.eq_parallelized(&ai.0, &bi.0); + let is_term = k.k.scalar_eq_parallelized(&bi.0, Self::TERMINATOR); + k.k.bitor_parallelized(&eq, &is_term) + }) + .collect::<Vec<_>>() + }, + || { + // If a is potentially shorter than b, ensure that overhang is empty. + match a.max_len() < b.max_len() { + true => Some(b.substr_clear(k, a.max_len(), b.max_len()).is_empty(k)), + false => None, + } + }, + ); + + if let Some(overhang_empty) = overhang_empty { + v.push(overhang_empty); + } + + // Check if all v[i] == 1. + binary_and_vec(k, &v) } - /// Returns `self[i..]`. If `i >= self.len`, returns the empty string. - fn substr_clear(&self, k: &ServerKey, i: usize) -> FheString { - let empty_string = Self::empty_string(k); - let v = self.0.get(i..).unwrap_or(&empty_string.0); + /// Returns `self[start..end]`. If `start >= self.len`, returns the empty + /// string. If `end > self.max_len`, set `end = self.max_len`. + fn substr_clear(&self, k: &ServerKey, start: usize, end: usize) -> FheString { + let end = cmp::min(self.max_len(), end); + let mut v = self.0.get(start..end).unwrap_or_default().to_vec(); + v.push(FheString::term_char(k)); FheString(v.to_vec()) } - fn empty_string(k: &ServerKey) -> Self { - let term = FheAsciiChar(k.create_value(Self::TERMINATOR)); - FheString(vec![term]) + // Converts the string into a long radix by concatenating its blocks. 
+ fn to_long_radix(&self) -> RadixCiphertext { + let blocks: Vec<_> = self + .0 + .iter() + .flat_map(|c| c.0.blocks().to_owned()) + .collect(); + RadixCiphertext::from_blocks(blocks) } } diff --git a/src/ciphertext/convert.rs b/src/ciphertext/convert.rs index d92cf65..d7bc52c 100644 --- a/src/ciphertext/convert.rs +++ b/src/ciphertext/convert.rs @@ -5,7 +5,7 @@ use tfhe::integer::RadixCiphertext; use crate::server_key::ServerKey; -use super::{FheAsciiChar, FheString, Uint}; +use super::{logic::binary_and, FheAsciiChar, FheString, Uint}; impl FheAsciiChar { const CASE_DIFF: Uint = 32; @@ -15,7 +15,7 @@ impl FheAsciiChar { // (65 <= c <= 90) let c_geq_65 = k.k.scalar_ge_parallelized(&self.0, 65 as Uint); let c_leq_90 = k.k.scalar_le_parallelized(&self.0, 90 as Uint); - k.k.mul_parallelized(&c_geq_65, &c_leq_90) + binary_and(k, &c_geq_65, &c_leq_90) } /// Returns whether `self` is lowercase. @@ -23,7 +23,7 @@ impl FheAsciiChar { // (97 <= c <= 122) let c_geq_97 = k.k.scalar_ge_parallelized(&self.0, 97 as Uint); let c_leq_122 = k.k.scalar_le_parallelized(&self.0, 122 as Uint); - k.k.mul_parallelized(&c_geq_97, &c_leq_122) + binary_and(k, &c_geq_97, &c_leq_122) } /// Returns the lowercase representation of `self`. 
diff --git a/src/ciphertext/mod.rs b/src/ciphertext/mod.rs index 49d809a..736cfb6 100644 --- a/src/ciphertext/mod.rs +++ b/src/ciphertext/mod.rs @@ -136,8 +136,8 @@ impl FheString { log::trace!("len: at index {i_sub_1}"); let self_isub1 = &pair[0]; let self_i = &pair[1]; - let self_isub1_neq_0 = k.k.scalar_ne_parallelized(&self_isub1.0, 0); - let self_i_eq_0 = k.k.scalar_eq_parallelized(&self_i.0, 0); + let self_isub1_neq_0 = k.k.scalar_ne_parallelized(&self_isub1.0, Self::TERMINATOR); + let self_i_eq_0 = k.k.scalar_eq_parallelized(&self_i.0, Self::TERMINATOR); let b = binary_and(k, &self_isub1_neq_0, &self_i_eq_0); let i = i_sub_1 + 1; k.k.scalar_mul_parallelized(&b, i as Uint) @@ -252,14 +252,21 @@ impl FheString { fn pad(&self, k: &ServerKey, l: usize) -> Self { if l > Self::max_len_with_key(k) { panic!("pad length exceeds maximum length") + } else if l < self.max_len() { + // Nothing to pad. + return self.clone(); } let mut v = self.0.to_vec(); - let term = FheAsciiChar(k.create_value(Self::TERMINATOR)); + let term = Self::term_char(k); // l + 1 because of termination character. (0..l + 1 - self.0.len()).for_each(|_| v.push(term.clone())); FheString(v) } + + fn term_char(k: &ServerKey) -> FheAsciiChar { + FheAsciiChar(k.create_value(Self::TERMINATOR)) + } } /// Given `v` and `Enc(i)`, return `v[i]`. Returns `0` if `i` is out of bounds. @@ -325,14 +332,20 @@ fn index_of_unchecked_with_options( }; // Evaluate predicate `p` on each element of `v`. - let p_eval: Vec<_> = items.par_iter().map(|(i, x)| (i, p(k, x))).collect(); + let p_eval: Vec<_> = items + .par_iter() + .map(|(i, x)| { + let pi = p(k, x); + let pi_mul_i = k.k.scalar_mul_parallelized(&pi, *i as Uint); + (i, pi, pi_mul_i) + }) + .collect(); // Find first index for which predicate evaluated to 1. - p_eval.into_iter().for_each(|(i, pi)| { + p_eval.into_iter().for_each(|(i, pi, pi_mul_i)| { log::trace!("index_of_opt_unchecked: at index {i}"); // index = b ? index : (pi ? 
i : 0) - let pi_mul_i = k.k.scalar_mul_parallelized(&pi, *i as Uint); index = binary_if_then_else(k, &b, &index, &pi_mul_i); // b = b || pi diff --git a/src/ciphertext/search.rs b/src/ciphertext/search.rs index 91577ed..0d4f9ed 100644 --- a/src/ciphertext/search.rs +++ b/src/ciphertext/search.rs @@ -1,6 +1,6 @@ //! Functionality for string search. -use rayon::prelude::*; +use rayon::{join, prelude::*}; use tfhe::integer::RadixCiphertext; use crate::{ @@ -223,8 +223,7 @@ impl FheString { let opti = self.rfind(k, s); // is_end = self.len == i + s.len - let self_len = self.len(k); - let s_len = s.len(k); + let (self_len, s_len) = join(|| self.len(k), || s.len(k)); let i_add_s_len = k.k.add_parallelized(&opti.val, &s_len); let is_end = k.k.eq_parallelized(&self_len, &i_add_s_len); diff --git a/src/ciphertext/trim.rs b/src/ciphertext/trim.rs index 49dfe65..2ef0012 100644 --- a/src/ciphertext/trim.rs +++ b/src/ciphertext/trim.rs @@ -19,7 +19,7 @@ impl FheAsciiChar { // (9 <= c <= 13) || c == 32 let c_geq_9 = k.k.scalar_ge_parallelized(&self.0, 9 as Uint); let c_leq_13 = k.k.scalar_le_parallelized(&self.0, 13 as Uint); - let c_geq_9_and_c_leq_13 = k.k.mul_parallelized(&c_geq_9, &c_leq_13); + let c_geq_9_and_c_leq_13 = binary_and(k, &c_geq_9, &c_leq_13); let c_eq_32 = k.k.scalar_eq_parallelized(&self.0, 32 as Uint); binary_or(k, &c_geq_9_and_c_leq_13, &c_eq_32) }