Skip to content

Commit

Permalink
Update to use the prio fork with fixedvec dp noise.
Browse files Browse the repository at this point in the history
The relevant pull request is here: divviup/libprio-rs#578.
  • Loading branch information
MxmUrw committed May 30, 2023
1 parent 7ad8faa commit f137dce
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 40 deletions.
28 changes: 20 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ janus_messages = { version = "0.5", path = "messages" }
k8s-openapi = { version = "0.18.0", features = ["v1_24"] } # keep this version in sync with what is referenced by the indirect dependency via `kube`
kube = { version = "0.82.2", default-features = false, features = ["client", "rustls-tls"] }
opentelemetry = { version = "0.19", features = ["metrics"] }
prio = { version = "0.12.1", features = ["multithreaded"] }
# prio = { version = "0.12.1", features = ["multithreaded"] }
serde = { version = "1.0.163", features = ["derive"] }
# prio = { version = "0.12.1", features = ["multithreaded"] }
prio = {git = "https://github.com/dpsa-project/libprio-rs.git", branch ="feature-fixedvec-dp", features = ["multithreaded"]}
rstest = "0.17.0"
trillium = "0.2.9"
trillium-api = { version = "0.2.0-rc.3", default-features = false }
Expand Down
12 changes: 6 additions & 6 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,25 +502,25 @@ impl<C: Clock> TaskAggregator<C> {
}

#[cfg(feature = "fpvec_bounded_l2")]
VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => {
VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length, noise_param } => {
let vdaf: Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>> =
Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, *length)?;
Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, *length, *noise_param)?;
let verify_key = task.primary_vdaf_verify_key()?;
VdafOps::Prio3FixedPoint16BitBoundedL2VecSum(Arc::new(vdaf), verify_key)
}

#[cfg(feature = "fpvec_bounded_l2")]
VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length } => {
VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length, noise_param } => {
let vdaf: Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>> =
Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, *length)?;
Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, *length, *noise_param)?;
let verify_key = task.primary_vdaf_verify_key()?;
VdafOps::Prio3FixedPoint32BitBoundedL2VecSum(Arc::new(vdaf), verify_key)
}

#[cfg(feature = "fpvec_bounded_l2")]
VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length } => {
VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length, noise_param } => {
let vdaf: Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>> =
Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, *length)?;
Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(2, *length, *noise_param)?;
let verify_key = task.primary_vdaf_verify_key()?;
VdafOps::Prio3FixedPoint64BitBoundedL2VecSum(Arc::new(vdaf), verify_key)
}
Expand Down
24 changes: 12 additions & 12 deletions aggregator/src/aggregator/aggregation_job_creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,11 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
#[cfg(feature = "fpvec_bounded_l2")]
(
task::QueryType::TimeInterval,
VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length },
VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length, noise_param },
) => {
let vdaf: Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>> =
Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<PRIO3_VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>>(task, vdaf)
.await
Expand All @@ -280,11 +280,11 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
#[cfg(feature = "fpvec_bounded_l2")]
(
task::QueryType::TimeInterval,
VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length },
VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length, noise_param },
) => {
let vdaf: Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>> =
Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<PRIO3_VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>>(task, vdaf)
.await
Expand All @@ -293,11 +293,11 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
#[cfg(feature = "fpvec_bounded_l2")]
(
task::QueryType::TimeInterval,
VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length },
VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length, noise_param },
) => {
let vdaf: Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>> =
Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<PRIO3_VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>>(task, vdaf)
.await
Expand Down Expand Up @@ -354,11 +354,11 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
#[cfg(feature = "fpvec_bounded_l2")]
(
task::QueryType::FixedSize { max_batch_size },
VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length },
VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length, noise_param },
) => {
let vdaf: Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>> =
Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?);
let max_batch_size = *max_batch_size;
self.create_aggregation_jobs_for_fixed_size_task_no_param::<PRIO3_VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI16<U15>>>(task, vdaf, max_batch_size)
Expand All @@ -368,11 +368,11 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
#[cfg(feature = "fpvec_bounded_l2")]
(
task::QueryType::FixedSize { max_batch_size },
VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length },
VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length, noise_param },
) => {
let vdaf: Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>> =
Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?);
let max_batch_size = *max_batch_size;
self.create_aggregation_jobs_for_fixed_size_task_no_param::<PRIO3_VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI32<U31>>>(task, vdaf, max_batch_size)
Expand All @@ -382,11 +382,11 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
#[cfg(feature = "fpvec_bounded_l2")]
(
task::QueryType::FixedSize { max_batch_size },
VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length },
VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length, noise_param },
) => {
let vdaf: Arc<Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>> =
Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?);
let max_batch_size = *max_batch_size;
self.create_aggregation_jobs_for_fixed_size_task_no_param::<PRIO3_VERIFY_KEY_LENGTH, Prio3FixedPointBoundedL2VecSumMultithreaded<FixedI64<U63>>>(task, vdaf, max_batch_size)
Expand Down
26 changes: 14 additions & 12 deletions core/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use reqwest::Url;
use ring::constant_time;
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
use std::str;
#[cfg(feature = "fpvec_bounded_l2")]
use prio::flp::types::fixedpoint_l2::PrivacyParameterType;

/// HTTP header where auth tokens are provided in messages between participants.
pub const DAP_AUTH_HEADER: &str = "DAP-Auth-Token";
Expand All @@ -35,13 +37,13 @@ pub enum VdafInstance {
Prio3Histogram { buckets: Vec<u64> },
/// A `Prio3` 16-bit fixed point vector sum with bounded L2 norm.
#[cfg(feature = "fpvec_bounded_l2")]
Prio3FixedPoint16BitBoundedL2VecSum { length: usize },
Prio3FixedPoint16BitBoundedL2VecSum { length: usize, noise_param: PrivacyParameterType },
/// A `Prio3` 32-bit fixed point vector sum with bounded L2 norm.
#[cfg(feature = "fpvec_bounded_l2")]
Prio3FixedPoint32BitBoundedL2VecSum { length: usize },
Prio3FixedPoint32BitBoundedL2VecSum { length: usize, noise_param: PrivacyParameterType },
/// A `Prio3` 64-bit fixedpoint vector sum with bounded L2 norm.
#[cfg(feature = "fpvec_bounded_l2")]
Prio3FixedPoint64BitBoundedL2VecSum { length: usize },
Prio3FixedPoint64BitBoundedL2VecSum { length: usize, noise_param: PrivacyParameterType },
/// The `poplar1` VDAF. Support for this VDAF is experimental.
Poplar1 { bits: usize },

Expand Down Expand Up @@ -167,23 +169,23 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 {
// Provide the dispatched type only, don't construct a VDAF instance.
(impl match fpvec_bounded_l2 $vdaf_instance:expr, (_, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => {
match $vdaf_instance {
::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => {
::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length, noise_param } => {
type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<
::fixed::FixedI16<::fixed::types::extra::U15>,
>;
const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
$body
}

::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length } => {
::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length, noise_param } => {
type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<
::fixed::FixedI32<::fixed::types::extra::U31>,
>;
const $VERIFY_KEY_LENGTH: usize = ::janus_core::task::PRIO3_VERIFY_KEY_LENGTH;
$body
}

::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length } => {
::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length, noise_param } => {
type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<
::fixed::FixedI64<::fixed::types::extra::U63>,
>;
Expand All @@ -198,10 +200,10 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 {
// Construct a VDAF instance, and provide that to the block as well.
(impl match fpvec_bounded_l2 $vdaf_instance:expr, ($vdaf:ident, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => {
match $vdaf_instance {
::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length } => {
::janus_core::task::VdafInstance::Prio3FixedPoint16BitBoundedL2VecSum { length, noise_param } => {
let $vdaf =
::prio::vdaf::prio3::Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?;
type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<
::fixed::FixedI16<::fixed::types::extra::U15>,
Expand All @@ -210,10 +212,10 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 {
$body
}

::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length } => {
::janus_core::task::VdafInstance::Prio3FixedPoint32BitBoundedL2VecSum { length, noise_param } => {
let $vdaf =
::prio::vdaf::prio3::Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?;
type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<
::fixed::FixedI32<::fixed::types::extra::U31>,
Expand All @@ -222,10 +224,10 @@ macro_rules! vdaf_dispatch_impl_fpvec_bounded_l2 {
$body
}

::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length } => {
::janus_core::task::VdafInstance::Prio3FixedPoint64BitBoundedL2VecSum { length, noise_param } => {
let $vdaf =
::prio::vdaf::prio3::Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded(
2, *length,
2, *length, *noise_param
)?;
type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSumMultithreaded<
::fixed::FixedI64<::fixed::types::extra::U63>,
Expand Down
3 changes: 2 additions & 1 deletion messages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ hex = "0.4"
num_enum = "0.6.1"
# We can't pull prio in from the workspace because that would enable default features, and we do not
# want prio/crypto-dependencies
prio = { version = "0.12.1", default-features = false }
# prio = { version = "0.12.1", default-features = false }
prio = {git = "https://github.com/dpsa-project/libprio-rs.git", branch ="feature-fixedvec-dp", default-features = false}
rand = "0.8"
serde.workspace = true
thiserror = "1.0"
Expand Down

0 comments on commit f137dce

Please sign in to comment.