Skip to content

Commit

Permalink
Merge pull request #3 from itzmeanjan/bench-pir
Browse files Browse the repository at this point in the history
Benchmark ChalametPIR
  • Loading branch information
itzmeanjan authored Jan 28, 2025
2 parents 74d2242 + 1b8be47 commit 7b4fcae
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 48 deletions.
16 changes: 16 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,22 @@ rand = "=0.8.5"
rand_chacha = "=0.3.1"
serde = { version = "=1.0.217", features = ["derive"] }
bincode = "=1.3.3"
rayon = "=1.10.0"

[dev-dependencies]
divan = "=0.1.17"

[[bench]]
name = "offline_phase"
harness = false

[[bench]]
name = "online_phase"
harness = false
required-features = ["mutate_internal_client_state"]

[features]
mutate_internal_client_state = []

[profile.optimized]
inherits = "release"
Expand Down
99 changes: 99 additions & 0 deletions benches/offline_phase.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use chalamet_pir::{client, server};
use divan;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use std::{collections::HashMap, time::Duration};

fn main() {
divan::main();
}

fn generate_random_kv_database(rng: &mut ChaCha8Rng, num_kv_pairs: usize, key_byte_len: usize, value_byte_len: usize) -> HashMap<Vec<u8>, Vec<u8>> {
assert!(key_byte_len > 0);
assert!(value_byte_len > 0);

let mut kv = HashMap::with_capacity(num_kv_pairs);

for _ in 0..num_kv_pairs {
let mut key = vec![0u8; key_byte_len];
let mut value = vec![0u8; value_byte_len];

rng.fill_bytes(&mut key);
rng.fill_bytes(&mut value);

kv.insert(key, value);
}

kv
}

#[derive(Debug)]
struct DBConfig {
db_entry_count: usize,
mat_elem_bit_len: usize,
key_byte_len: usize,
value_byte_len: usize,
}

const ARGS: &[DBConfig] = &[
DBConfig {
db_entry_count: 1usize << 16,
mat_elem_bit_len: 10,
key_byte_len: 32,
value_byte_len: 1024,
},
DBConfig {
db_entry_count: 1usize << 18,
mat_elem_bit_len: 10,
key_byte_len: 32,
value_byte_len: 1024,
},
DBConfig {
db_entry_count: 1usize << 20,
mat_elem_bit_len: 9,
key_byte_len: 32,
value_byte_len: 1024,
},
DBConfig {
db_entry_count: 1usize << 22,
mat_elem_bit_len: 9,
key_byte_len: 32,
value_byte_len: 1024,
},
DBConfig {
db_entry_count: 1usize << 24,
mat_elem_bit_len: 8,
key_byte_len: 32,
value_byte_len: 1024,
},
];
const ARITIES: [u32; 2] = [3, 4];

#[divan::bench(args = ARGS, consts = ARITIES, max_time = Duration::from_secs(300), skip_ext_time = true)]
fn server_setup<const ARITY: u32>(bencher: divan::Bencher, db_config: &DBConfig) {
let mut rng = ChaCha8Rng::from_entropy();

let kv = generate_random_kv_database(&mut rng, db_config.db_entry_count, db_config.key_byte_len, db_config.value_byte_len);
let kv_as_ref = kv.iter().map(|(k, v)| (k.as_slice(), v.as_slice())).collect::<HashMap<&[u8], &[u8]>>();

let mut seed_μ = [0u8; server::SEED_BYTE_LEN];
rng.fill_bytes(&mut seed_μ);

bencher
.with_inputs(|| (kv_as_ref.clone(), seed_μ.clone()))
.bench_values(|(kv, seed)| server::Server::setup::<ARITY>(divan::black_box(db_config.mat_elem_bit_len), divan::black_box(&seed), divan::black_box(kv)));
}

#[divan::bench(args = ARGS, consts = ARITIES, max_time = Duration::from_secs(300), skip_ext_time = true)]
fn client_setup<const ARITY: u32>(bencher: divan::Bencher, db_config: &DBConfig) {
let mut rng = ChaCha8Rng::from_entropy();

let kv = generate_random_kv_database(&mut rng, db_config.db_entry_count, db_config.key_byte_len, db_config.value_byte_len);
let kv_as_ref = kv.iter().map(|(k, v)| (k.as_slice(), v.as_slice())).collect::<HashMap<&[u8], &[u8]>>();

let mut seed_μ = [0u8; server::SEED_BYTE_LEN];
rng.fill_bytes(&mut seed_μ);

let (_, hint_bytes, filter_param_bytes) = server::Server::setup::<ARITY>(db_config.mat_elem_bit_len, &seed_μ, kv_as_ref).expect("Server setup failed");
bencher.bench(|| client::Client::setup(divan::black_box(&seed_μ), divan::black_box(&hint_bytes), divan::black_box(&filter_param_bytes)));
}
137 changes: 137 additions & 0 deletions benches/online_phase.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use chalamet_pir::{client, server};
use divan;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use std::{collections::HashMap, time::Duration};

fn main() {
divan::main();
}

fn generate_random_kv_database(rng: &mut ChaCha8Rng, num_kv_pairs: usize, key_byte_len: usize, value_byte_len: usize) -> HashMap<Vec<u8>, Vec<u8>> {
assert!(key_byte_len > 0);
assert!(value_byte_len > 0);

let mut kv = HashMap::with_capacity(num_kv_pairs);

for _ in 0..num_kv_pairs {
let mut key = vec![0u8; key_byte_len];
let mut value = vec![0u8; value_byte_len];

rng.fill_bytes(&mut key);
rng.fill_bytes(&mut value);

kv.insert(key, value);
}

kv
}

#[derive(Debug)]
struct DBConfig {
db_entry_count: usize,
mat_elem_bit_len: usize,
key_byte_len: usize,
value_byte_len: usize,
}

const ARGS: &[DBConfig] = &[
DBConfig {
db_entry_count: 1usize << 16,
mat_elem_bit_len: 10,
key_byte_len: 32,
value_byte_len: 1024,
},
DBConfig {
db_entry_count: 1usize << 18,
mat_elem_bit_len: 10,
key_byte_len: 32,
value_byte_len: 1024,
},
DBConfig {
db_entry_count: 1usize << 20,
mat_elem_bit_len: 9,
key_byte_len: 32,
value_byte_len: 1024,
},
DBConfig {
db_entry_count: 1usize << 22,
mat_elem_bit_len: 9,
key_byte_len: 32,
value_byte_len: 1024,
},
DBConfig {
db_entry_count: 1usize << 24,
mat_elem_bit_len: 8,
key_byte_len: 32,
value_byte_len: 1024,
},
];
const ARITIES: [u32; 2] = [3, 4];

#[divan::bench(args = ARGS, consts = ARITIES, max_time = Duration::from_secs(300), skip_ext_time = true)]
fn client_query<const ARITY: u32>(bencher: divan::Bencher, db_config: &DBConfig) {
let mut rng = ChaCha8Rng::from_entropy();

let kv = generate_random_kv_database(&mut rng, db_config.db_entry_count, db_config.key_byte_len, db_config.value_byte_len);
let kv_as_ref = kv.iter().map(|(k, v)| (k.as_slice(), v.as_slice())).collect::<HashMap<&[u8], &[u8]>>();

let mut seed_μ = [0u8; server::SEED_BYTE_LEN];
rng.fill_bytes(&mut seed_μ);

let (_, hint_bytes, filter_param_bytes) = server::Server::setup::<ARITY>(db_config.mat_elem_bit_len, &seed_μ, kv_as_ref.clone()).unwrap();
let client = client::Client::setup(&seed_μ, &hint_bytes, &filter_param_bytes).unwrap();

let (&key, _) = kv_as_ref.iter().last().unwrap();

bencher.with_inputs(|| client.clone()).bench_refs(|client| {
let _ = divan::black_box(&mut *client).query(divan::black_box(key));
client.discard_query(key);
});
}

#[divan::bench(args = ARGS, consts = ARITIES, max_time = Duration::from_secs(300), skip_ext_time = true)]
fn server_respond<const ARITY: u32>(bencher: divan::Bencher, db_config: &DBConfig) {
let mut rng = ChaCha8Rng::from_entropy();

let kv = generate_random_kv_database(&mut rng, db_config.db_entry_count, db_config.key_byte_len, db_config.value_byte_len);
let kv_as_ref = kv.iter().map(|(k, v)| (k.as_slice(), v.as_slice())).collect::<HashMap<&[u8], &[u8]>>();

let mut seed_μ = [0u8; server::SEED_BYTE_LEN];
rng.fill_bytes(&mut seed_μ);

let (server, hint_bytes, filter_param_bytes) = server::Server::setup::<ARITY>(db_config.mat_elem_bit_len, &seed_μ, kv_as_ref.clone()).unwrap();
let mut client = client::Client::setup(&seed_μ, &hint_bytes, &filter_param_bytes).unwrap();

let (&key, _) = kv_as_ref.iter().last().unwrap();
let query_bytes = client.query(key).unwrap();

bencher.bench(|| divan::black_box(&server).respond(divan::black_box(&query_bytes)));
}

#[divan::bench(args = ARGS, consts = ARITIES, max_time = Duration::from_secs(300), skip_ext_time = true)]
fn client_process_response<const ARITY: u32>(bencher: divan::Bencher, db_config: &DBConfig) {
let mut rng = ChaCha8Rng::from_entropy();

let kv = generate_random_kv_database(&mut rng, db_config.db_entry_count, db_config.key_byte_len, db_config.value_byte_len);
let kv_as_ref = kv.iter().map(|(k, v)| (k.as_slice(), v.as_slice())).collect::<HashMap<&[u8], &[u8]>>();

let mut seed_μ = [0u8; server::SEED_BYTE_LEN];
rng.fill_bytes(&mut seed_μ);

let (server, hint_bytes, filter_param_bytes) = server::Server::setup::<ARITY>(db_config.mat_elem_bit_len, &seed_μ, kv_as_ref.clone()).unwrap();
let mut client = client::Client::setup(&seed_μ, &hint_bytes, &filter_param_bytes).unwrap();

let (&key, _) = kv_as_ref.iter().last().unwrap();
let query_bytes = client.query(key).unwrap();

let query = client.discard_query(key).unwrap();
client.insert_query(key, query.clone());

let response_bytes = server.respond(&query_bytes).unwrap();

bencher.with_inputs(|| client.clone()).bench_refs(|client| {
let _ = divan::black_box(&mut *client).process_response(divan::black_box(key), divan::black_box(&response_bytes));
client.insert_query(key, query.clone());
});
}
25 changes: 20 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
pub use crate::pir_internals::params::SEED_BYTE_LEN;
use crate::pir_internals::{
binary_fuse_filter::{self, BinaryFuseFilter},
branch_opt_util,
matrix::Matrix,
params::{LWE_DIMENSION, SEED_BYTE_LEN},
params::LWE_DIMENSION,
serialization,
};
use std::collections::HashMap;

#[derive(Clone)]
pub struct Query {
vec_c: Matrix,
}

#[derive(Clone)]
pub struct Client<'a> {
pub_mat_a: Matrix,
hint_mat_m: Matrix,
Expand All @@ -36,6 +39,18 @@ impl<'a> Client<'a> {
})
}

#[cfg(feature = "mutate_internal_client_state")]
#[inline(always)]
pub fn discard_query(&mut self, key: &'a [u8]) -> Option<Query> {
self.pending_queries.remove(key)
}

#[cfg(feature = "mutate_internal_client_state")]
#[inline(always)]
pub fn insert_query(&mut self, key: &'a [u8], query: Query) {
self.pending_queries.insert(key, query);
}

pub fn query(&mut self, key: &'a [u8]) -> Option<Vec<u8>> {
match self.filter.arity {
3 => self.query_for_3_wise_xor_filter(key),
Expand All @@ -55,7 +70,7 @@ impl<'a> Client<'a> {
let secret_vec_num_cols = LWE_DIMENSION;
let secret_vec_s = Matrix::sample_from_uniform_ternary_dist(1, secret_vec_num_cols)?;

let error_vector_num_cols = self.pub_mat_a.get_num_cols();
let error_vector_num_cols = self.pub_mat_a.num_cols();
let error_vec_e = Matrix::sample_from_uniform_ternary_dist(1, error_vector_num_cols)?;

let mut query_vec_b = ((&secret_vec_s * &self.pub_mat_a)? + error_vec_e)?;
Expand Down Expand Up @@ -102,7 +117,7 @@ impl<'a> Client<'a> {
let secret_vec_num_cols = LWE_DIMENSION;
let secret_vec_s = Matrix::sample_from_uniform_ternary_dist(1, secret_vec_num_cols)?;

let error_vector_num_cols = self.pub_mat_a.get_num_cols();
let error_vector_num_cols = self.pub_mat_a.num_cols();
let error_vec_e = Matrix::sample_from_uniform_ternary_dist(1, error_vector_num_cols)?;

let mut query_vec_b = ((&secret_vec_s * &self.pub_mat_a)? + error_vec_e)?;
Expand Down Expand Up @@ -154,7 +169,7 @@ impl<'a> Client<'a> {
let secret_vec_c = &query.vec_c;

let response_vector = Matrix::from_bytes(response_bytes).ok()?;
if branch_opt_util::unlikely(!(response_vector.get_num_rows() == 1 && response_vector.get_num_cols() == secret_vec_c.get_num_cols())) {
if branch_opt_util::unlikely(!(response_vector.num_rows() == 1 && response_vector.num_cols() == secret_vec_c.num_cols())) {
return None;
}

Expand All @@ -165,7 +180,7 @@ impl<'a> Client<'a> {
let hashed_key = binary_fuse_filter::hash_of_key(key);
let hash = binary_fuse_filter::mix256(&hashed_key, &self.filter.seed);

let recovered_row = (0..response_vector.get_num_cols())
let recovered_row = (0..response_vector.num_cols())
.map(|idx| {
let unscaled_res = response_vector[(0, idx)].wrapping_sub(secret_vec_c[(0, idx)]);

Expand Down
2 changes: 1 addition & 1 deletion src/pir_internals/binary_fuse_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use sha3::{Digest, Sha3_256};
use std::collections::HashMap;

#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Clone)]
pub struct BinaryFuseFilter {
pub seed: [u8; 32],
pub arity: u32,
Expand Down
Loading

0 comments on commit 7b4fcae

Please sign in to comment.