Skip to content

Commit

Permalink
Optimize comparison algorithms (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasgeihs authored Jan 24, 2024
1 parent 6f050ea commit 33033bd
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 73 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: Test
on: [pull_request]

jobs:
test:
name: cargo test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- run: cargo test --release --features "tfhe/x86_64-unix" -- --test-threads=1
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tfhe = { version = "0.4.1", features = [ "boolean", "shortint", "integer", "aarch64-unix" ] }
tfhe = { version = "0.4.1", features = [ "boolean", "shortint", "integer" ] }
serde = { version = "1.0", features = ["derive"] }
rayon = "1.8"
env_logger = "0.10.0"
Expand Down
35 changes: 3 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,38 +70,9 @@ cargo doc --no-deps --open
cargo test --doc --release -- --show-output
```

## State of this project

## Acknowledgements
This project has been developed for the [Zama Bounty Program](https://github.com/zama-ai/bounty-program), specifically for the bounty ["Create a string library that works on encrypted data using TFHE-rs"](https://github.com/zama-ai/bounty-program/issues/80).

### Deviations from bounty description

We chose to develop this library under the principle **"everything encrypted first"**. This means that support for operations with encrypted inputs has been prioritized over support for operations where parts of the input (e.g., the pattern) is not encrypted, or where the strings are encrypted in a way that leaks their length.
In the following we list some aspects in which our implementation deviates from the bounty description.

#### Cleartext input

- *No optimizations for unpadded strings:* The original bounty description stated that all strings should be 0-padded. Later, this requirement was relaxed (see note in [bounty description](https://github.com/zama-ai/bounty-program/issues/80)) to allow for unpadded strings that are indentifiable as such without decryption. Due to time constraints, unpadded strings, or any optimizations in that regard, are currently not implemented. However, we do list potential optimizations further below.

- *No optimizations for partial cleartext input:* We did not implement a dedicated cleartext API or any optimizations for it. We support these operations by first encrypting the cleartext inputs and then calling the corresponding ciphertext API.

#### Project structure

- *String functions implemented on `FheString` instead of `ServerKey`:* The bounty description asks for the string functions to be implemented on the server key type. However, we found it to be more intuitive to have the functions on the `FheString` type, similar to how regular string functions are available on their string type. (Obviously, this can easily be changed on request.)

- *Standalone library instead of `tfhe-rs` example:* The bounty description asks for the code be provided as an example that is part of the `tfhe-rs` codebase. However, we found that compilation times were much longer when compiling the code in form of an example compared compiling it as a standalone library. As this was limiting code iteration time, we decided to develop and provide the code in form of a standalone library. (Obviously, this can easily be changed on request.)

#### String length

- *Restricted to strings of length < 256:* Currently, the library does not support encrypted strings longer than 255 characters. This is due to the fact that for our `FheString` algorithms to work, we need to be able to represent encrypted integers up to the maximum string length. The size of encrypted integers is fixed at key generation. We could have opted for supporting longer strings (in fact, this is an easy change to the key generation function), but we felt that 256 characters is more than enough initially, considering the limited performance.

### Potential optimizations

Currently, all encrypted strings are 0-padded.
In the following, we outline a number of potential optimizations that could be applied if support for unpadded encrypted strings is added in the future.
## License

- `ends_with`: currently need to go through whole string because we don't know
length. then only need to compare the respective ends of the two encrypted
strings.
- `add`, `repeat`: currently this is a quadratic operation because we don't know
where the boundaries are. if we don't have padding, we can just append.
TBD
3 changes: 3 additions & 0 deletions examples/cmd/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ fn main() {
},
];

let start = Instant::now();
test_cases.iter().for_each(|t| {
let start = Instant::now();
let result_std = (t.std)(&args);
Expand All @@ -525,6 +526,8 @@ fn main() {
}
)
});
let duration = start.elapsed();
println!("\nDuration (total): {:?}", duration);
}

trait TestCaseOutput: Debug {
Expand Down
98 changes: 71 additions & 27 deletions src/ciphertext/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
use std::cmp;

use rayon::{join, prelude::*};
use tfhe::integer::RadixCiphertext;
use tfhe::integer::{IntegerCiphertext, RadixCiphertext};

use crate::server_key::ServerKey;

use super::{
logic::{binary_and, binary_and_vec, binary_not, binary_or},
FheAsciiChar, FheString,
FheString,
};

impl FheString {
Expand All @@ -23,19 +23,32 @@ impl FheString {
/// Returns `self == s`. The result is an encryption of 1 if this is the
/// case and an encryption of 0 otherwise.
pub fn eq(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext {
// Pad to same length.
let l = cmp::max(self.max_len(), s.max_len());
let a = self.pad(k, l);
let b = s.pad(k, l);
// Compare overlapping part.
let l = cmp::min(self.max_len(), s.max_len());
let a = self.substr_clear(k, 0, l);
let b = s.substr_clear(k, 0, l);

// is_eq[i] = a[i] == b[i]
let is_eq =
a.0.par_iter()
.zip(b.0)
.map(|(ai, bi)| k.k.eq_parallelized(&ai.0, &bi.0))
.collect::<Vec<_>>();
let (overlap_eq, overhang_empty) = join(
|| {
// Convert strings to radix integers and rely on optimized comparison.
let radix_a = a.to_long_radix();
let radix_b = b.to_long_radix();
let eq = k.k.eq_parallelized(&radix_a, &radix_b);

binary_and_vec(k, &is_eq)
// Trim exceeding radix blocks to ensure compatibility.
k.k.trim_radix_blocks_msb(&eq, eq.blocks().len() - k.num_blocks)
},
|| {
// Ensure that overhang is empty.
match self.max_len().cmp(&s.max_len()) {
cmp::Ordering::Greater => self.substr_clear(k, l, self.max_len()).is_empty(k),
cmp::Ordering::Less => s.substr_clear(k, l, s.max_len()).is_empty(k),
cmp::Ordering::Equal => k.create_one(),
}
},
);

binary_and(k, &overlap_eq, &overhang_empty)
}

/// Returns `self != s`. The result is an encryption of 1 if this is the
Expand Down Expand Up @@ -86,7 +99,7 @@ impl FheString {
is_lt = binary_or(k, &is_lt, &ai_lt_bi_and_eq);

// is_eq = is_eq && ai == bi
is_eq = k.k.mul_parallelized(&is_eq, ai_eq_bi);
is_eq = binary_and(k, &is_eq, ai_eq_bi);
});
is_lt
}
Expand Down Expand Up @@ -125,26 +138,57 @@ impl FheString {
binary_and_vec(k, &v)
}

/// Returns whether `self[i..i+s.len]` and `s` are equal. The result is an
/// encryption of 1 if this is the case and an encryption of 0 otherwise.
/// Returns whether `self[i..i+s.len]` and `s` are equal.
pub fn substr_eq(&self, k: &ServerKey, i: usize, s: &FheString) -> RadixCiphertext {
// Extract substring.
let a = self.substr_clear(k, i, self.max_len());
let b = s;
let b_len = b.len(k);
let a = self.substr_clear(k, i);
let a = a.truncate(k, &b_len);
a.eq(k, b)

let (mut v, overhang_empty) = join(
|| {
// v[i] = a[i] == b[i] && b[i] != 0
a.0.par_iter()
.zip(&b.0)
.map(|(ai, bi)| {
let eq = k.k.eq_parallelized(&ai.0, &bi.0);
let is_term = k.k.scalar_eq_parallelized(&bi.0, Self::TERMINATOR);
k.k.bitor_parallelized(&eq, &is_term)
})
.collect::<Vec<_>>()
},
|| {
// If a is potentially shorter than b, ensure that overhang is empty.
match a.max_len() < b.max_len() {
true => Some(b.substr_clear(k, a.max_len(), b.max_len()).is_empty(k)),
false => None,
}
},
);

if let Some(overhang_empty) = overhang_empty {
v.push(overhang_empty);
}

// Check if all v[i] == 1.
binary_and_vec(k, &v)
}

/// Returns `self[i..]`. If `i >= self.len`, returns the empty string.
fn substr_clear(&self, k: &ServerKey, i: usize) -> FheString {
let empty_string = Self::empty_string(k);
let v = self.0.get(i..).unwrap_or(&empty_string.0);
/// Returns `self[start..end]`. If `start >= self.len`, returns the empty
/// string. If `end > self.max_len`, set `end = self.max_len`.
fn substr_clear(&self, k: &ServerKey, start: usize, end: usize) -> FheString {
let end = cmp::min(self.max_len(), end);
let mut v = self.0.get(start..end).unwrap_or_default().to_vec();
v.push(FheString::term_char(k));
FheString(v.to_vec())
}

fn empty_string(k: &ServerKey) -> Self {
let term = FheAsciiChar(k.create_value(Self::TERMINATOR));
FheString(vec![term])
// Converts the string into a long radix by concatenating its blocks.
fn to_long_radix(&self) -> RadixCiphertext {
let blocks: Vec<_> = self
.0
.iter()
.flat_map(|c| c.0.blocks().to_owned())
.collect();
RadixCiphertext::from_blocks(blocks)
}
}
6 changes: 3 additions & 3 deletions src/ciphertext/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tfhe::integer::RadixCiphertext;

use crate::server_key::ServerKey;

use super::{FheAsciiChar, FheString, Uint};
use super::{logic::binary_and, FheAsciiChar, FheString, Uint};

impl FheAsciiChar {
const CASE_DIFF: Uint = 32;
Expand All @@ -15,15 +15,15 @@ impl FheAsciiChar {
// (65 <= c <= 90)
let c_geq_65 = k.k.scalar_ge_parallelized(&self.0, 65 as Uint);
let c_leq_90 = k.k.scalar_le_parallelized(&self.0, 90 as Uint);
k.k.mul_parallelized(&c_geq_65, &c_leq_90)
binary_and(k, &c_geq_65, &c_leq_90)
}

/// Returns whether `self` is lowercase.
pub fn is_lowercase(&self, k: &ServerKey) -> RadixCiphertext {
// (97 <= c <= 122)
let c_geq_97 = k.k.scalar_ge_parallelized(&self.0, 97 as Uint);
let c_leq_122 = k.k.scalar_le_parallelized(&self.0, 122 as Uint);
k.k.mul_parallelized(&c_geq_97, &c_leq_122)
binary_and(k, &c_geq_97, &c_leq_122)
}

/// Returns the lowercase representation of `self`.
Expand Down
25 changes: 19 additions & 6 deletions src/ciphertext/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ impl FheString {
log::trace!("len: at index {i_sub_1}");
let self_isub1 = &pair[0];
let self_i = &pair[1];
let self_isub1_neq_0 = k.k.scalar_ne_parallelized(&self_isub1.0, 0);
let self_i_eq_0 = k.k.scalar_eq_parallelized(&self_i.0, 0);
let self_isub1_neq_0 = k.k.scalar_ne_parallelized(&self_isub1.0, Self::TERMINATOR);
let self_i_eq_0 = k.k.scalar_eq_parallelized(&self_i.0, Self::TERMINATOR);
let b = binary_and(k, &self_isub1_neq_0, &self_i_eq_0);
let i = i_sub_1 + 1;
k.k.scalar_mul_parallelized(&b, i as Uint)
Expand Down Expand Up @@ -252,14 +252,21 @@ impl FheString {
fn pad(&self, k: &ServerKey, l: usize) -> Self {
if l > Self::max_len_with_key(k) {
panic!("pad length exceeds maximum length")
} else if l < self.max_len() {
// Nothing to pad.
return self.clone();
}

let mut v = self.0.to_vec();
let term = FheAsciiChar(k.create_value(Self::TERMINATOR));
let term = Self::term_char(k);
// l + 1 because of termination character.
(0..l + 1 - self.0.len()).for_each(|_| v.push(term.clone()));
FheString(v)
}

fn term_char(k: &ServerKey) -> FheAsciiChar {
FheAsciiChar(k.create_value(Self::TERMINATOR))
}
}

/// Given `v` and `Enc(i)`, return `v[i]`. Returns `0` if `i` is out of bounds.
Expand Down Expand Up @@ -325,14 +332,20 @@ fn index_of_unchecked_with_options<T: Sync>(
};

// Evaluate predicate `p` on each element of `v`.
let p_eval: Vec<_> = items.par_iter().map(|(i, x)| (i, p(k, x))).collect();
let p_eval: Vec<_> = items
.par_iter()
.map(|(i, x)| {
let pi = p(k, x);
let pi_mul_i = k.k.scalar_mul_parallelized(&pi, *i as Uint);
(i, pi, pi_mul_i)
})
.collect();

// Find first index for which predicate evaluated to 1.
p_eval.into_iter().for_each(|(i, pi)| {
p_eval.into_iter().for_each(|(i, pi, pi_mul_i)| {
log::trace!("index_of_opt_unchecked: at index {i}");

// index = b ? index : (pi ? i : 0)
let pi_mul_i = k.k.scalar_mul_parallelized(&pi, *i as Uint);
index = binary_if_then_else(k, &b, &index, &pi_mul_i);

// b = b || pi
Expand Down
5 changes: 2 additions & 3 deletions src/ciphertext/search.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Functionality for string search.
use rayon::prelude::*;
use rayon::{join, prelude::*};
use tfhe::integer::RadixCiphertext;

use crate::{
Expand Down Expand Up @@ -223,8 +223,7 @@ impl FheString {
let opti = self.rfind(k, s);

// is_end = self.len == i + s.len
let self_len = self.len(k);
let s_len = s.len(k);
let (self_len, s_len) = join(|| self.len(k), || s.len(k));
let i_add_s_len = k.k.add_parallelized(&opti.val, &s_len);
let is_end = k.k.eq_parallelized(&self_len, &i_add_s_len);

Expand Down
2 changes: 1 addition & 1 deletion src/ciphertext/trim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl FheAsciiChar {
// (9 <= c <= 13) || c == 32
let c_geq_9 = k.k.scalar_ge_parallelized(&self.0, 9 as Uint);
let c_leq_13 = k.k.scalar_le_parallelized(&self.0, 13 as Uint);
let c_geq_9_and_c_leq_13 = k.k.mul_parallelized(&c_geq_9, &c_leq_13);
let c_geq_9_and_c_leq_13 = binary_and(k, &c_geq_9, &c_leq_13);
let c_eq_32 = k.k.scalar_eq_parallelized(&self.0, 32 as Uint);
binary_or(k, &c_geq_9_and_c_leq_13, &c_eq_32)
}
Expand Down

0 comments on commit 33033bd

Please sign in to comment.