Skip to content

Commit

Permalink
Fix usage of Url::join. (#380)
Browse files Browse the repository at this point in the history
* Update integration test to serve Janus-Janus test out of /dap/ prefix.

This exercises a current bug in Janus where it can't interoperate with
an aggregator that is not serving out of the "base" of its domain (i.e.
"https://example.com/" will work, "https://example.com/dap/" will not).
The test currently fails; the following commit will fix this bug & make
the test pass.

This was missed because all existing tests of sufficient complexity to
potentially exercise a bug in aggregator endpoint handling were testing
against endpoints set to the base of a domain.

* Task: ensure aggregator endpoints end in a slash.

This is required to fix a *different* footgun in URL than the one
causing our non-base-domain heartache: if a URL doesn't end in a slash,
Url::join will, rather than joining the paths, replace the final path
element with the additional path elements.

I also update deserialization to go through the constructor -- I believe
all constructed Task values now go through the constructor. I don't fix
that the aggregator_endpoints field is pub -- hypothetically someone
that owns a Task value could mutate the field. We should probably do as
other types do & make all fields private with getters created as
necessary.

* Join relative, rather than relative, URL paths.

Janus joined absolute paths to aggregator endpoint URLs in several
locations. Url::join treats an absolute path as replacing the existing
path. This is not documented [1], nor do any of the documentation's
examples show this behavior, but the behavior is demonstrated by [2].
Due to this, Janus would remove any existing path from an aggregator
endpoint before sending requests of the affected types, effectively
breaking interoperation.

[1] https://docs.rs/url/2.2.2/url/struct.Url.html#method.join
[2] https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=6b80dd9c27d87613afda5e2e440441cd

* ClientParameters: ensure aggregator endpoints end in a slash.

We do this for the same reason we ensure the aggregator endpoints in a
Task end in a slash.
  • Loading branch information
branlwyd authored Aug 11, 2022
1 parent e27d243 commit da42381
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 58 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 38 additions & 18 deletions janus_client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,18 @@ impl ClientParameters {
/// Creates a new set of client task parameters.
pub fn new(
task_id: TaskId,
aggregator_endpoints: Vec<Url>,
mut aggregator_endpoints: Vec<Url>,
min_batch_duration: Duration,
) -> Self {
// Ensure provided aggregator endpoints end with a slash, as we will be joining additional
// path segments into these endpoints & the Url::join implementation is persnickety about
// the slash at the end of the path.
for url in &mut aggregator_endpoints {
if !url.as_str().ends_with('/') {
url.set_path(&format!("{}/", url.path()));
}
}

Self {
task_id,
aggregator_endpoints,
Expand Down Expand Up @@ -237,30 +246,41 @@ mod tests {
where
for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
{
let task_id = TaskId::random();

let clock = MockClock::default();
let (leader_hpke_config, _) = generate_hpke_config_and_private_key();
let (helper_hpke_config, _) = generate_hpke_config_and_private_key();

let server_url = Url::parse(&mockito::server_url()).unwrap();

let client_parameters = ClientParameters {
task_id,
aggregator_endpoints: vec![server_url.clone(), server_url],
min_batch_duration: Duration::from_seconds(1),
};

Client::new(
client_parameters,
ClientParameters::new(
TaskId::random(),
Vec::from([server_url.clone(), server_url]),
Duration::from_seconds(1),
),
vdaf_client,
clock,
MockClock::default(),
&default_http_client().unwrap(),
leader_hpke_config,
helper_hpke_config,
generate_hpke_config_and_private_key().0,
generate_hpke_config_and_private_key().0,
)
}

#[test]
fn aggregator_endpoints_end_in_slash() {
let client_parameters = ClientParameters::new(
TaskId::random(),
Vec::from([
"http://leader_endpoint/foo/bar".parse().unwrap(),
"http://helper_endpoint".parse().unwrap(),
]),
Duration::from_seconds(1),
);

assert_eq!(
client_parameters.aggregator_endpoints,
Vec::from([
"http://leader_endpoint/foo/bar/".parse().unwrap(),
"http://helper_endpoint/".parse().unwrap()
])
);
}

#[tokio::test]
async fn upload_prio3_count() {
install_test_trace_subscriber();
Expand Down
2 changes: 1 addition & 1 deletion janus_server/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2065,7 +2065,7 @@ where
const CORS_PREFLIGHT_CACHE_AGE: u32 = 24 * 60 * 60;

/// Constructs a Warp filter with endpoints common to all aggregators.
fn aggregator_filter<C: Clock>(
pub fn aggregator_filter<C: Clock>(
datastore: Arc<Datastore<C>>,
clock: C,
) -> Result<BoxedFilter<(impl Reply,)>, Error> {
Expand Down
5 changes: 1 addition & 4 deletions janus_server/src/aggregator/aggregate_share.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,7 @@ impl CollectJobDriver {

let response = self
.http_client
.post(
task.aggregator_url(Role::Helper)?
.join("/aggregate_share")?,
)
.post(task.aggregator_url(Role::Helper)?.join("aggregate_share")?)
.header(CONTENT_TYPE, AggregateShareReq::MEDIA_TYPE)
.header(
DAP_AUTH_HEADER,
Expand Down
4 changes: 2 additions & 2 deletions janus_server/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ impl AggregationJobDriver {

let response = self
.http_client
.post(task.aggregator_url(Role::Helper)?.join("/aggregate")?)
.post(task.aggregator_url(Role::Helper)?.join("aggregate")?)
.header(CONTENT_TYPE, AggregateInitializeReq::MEDIA_TYPE)
.header(
DAP_AUTH_HEADER,
Expand Down Expand Up @@ -504,7 +504,7 @@ impl AggregationJobDriver {

let response = self
.http_client
.post(task.aggregator_url(Role::Helper)?.join("/aggregate")?)
.post(task.aggregator_url(Role::Helper)?.join("aggregate")?)
.header(CONTENT_TYPE, AggregateContinueReq::MEDIA_TYPE)
.header(
DAP_AUTH_HEADER,
Expand Down
82 changes: 62 additions & 20 deletions janus_server/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl Task {
/// Create a new [`Task`] from the provided values
pub fn new<I: IntoIterator<Item = (HpkeConfig, HpkePrivateKey)>>(
task_id: TaskId,
aggregator_endpoints: Vec<Url>,
mut aggregator_endpoints: Vec<Url>,
vdaf: VdafInstance,
role: Role,
vdaf_verify_keys: Vec<Vec<u8>>,
Expand Down Expand Up @@ -239,6 +239,15 @@ impl Task {
return Err(Error::InvalidParameter("vdaf_verify_keys"));
}

// Ensure provided aggregator endpoints end with a slash, as we will be joining additional
// path segments into these endpoints & the Url::join implementation is persnickety about
// the slash at the end of the path.
for url in &mut aggregator_endpoints {
if !url.as_str().ends_with('/') {
url.set_path(&format!("{}/", url.path()));
}
}

// Compute hpke_configs mapping cfg.id -> (cfg, key).
let hpke_configs: HashMap<HpkeConfigId, (HpkeConfig, HpkePrivateKey)> = hpke_keys
.into_iter()
Expand Down Expand Up @@ -424,32 +433,28 @@ impl<'de> Deserialize<'de> for Task {
.collect::<Result<_, _>>()?;

// hpke_keys
let hpke_keys: HashMap<_, _> = serialized_task
let hpke_keys: Vec<(_, _)> = serialized_task
.hpke_keys
.into_iter()
.map(|keypair| {
Ok((
keypair.config.id,
keypair.try_into().map_err(D::Error::custom)?,
))
})
.map(|keypair| keypair.try_into().map_err(D::Error::custom))
.collect::<Result<_, _>>()?;

Ok(Task {
id: task_id,
aggregator_endpoints: serialized_task.aggregator_endpoints,
vdaf: serialized_task.vdaf,
role: serialized_task.role,
Task::new(
task_id,
serialized_task.aggregator_endpoints,
serialized_task.vdaf,
serialized_task.role,
vdaf_verify_keys,
max_batch_lifetime: serialized_task.max_batch_lifetime,
min_batch_size: serialized_task.min_batch_size,
min_batch_duration: serialized_task.min_batch_duration,
tolerable_clock_skew: serialized_task.tolerable_clock_skew,
serialized_task.max_batch_lifetime,
serialized_task.min_batch_size,
serialized_task.min_batch_duration,
serialized_task.tolerable_clock_skew,
collector_hpke_config,
aggregator_auth_tokens,
collector_auth_tokens,
hpke_keys,
})
)
.map_err(D::Error::custom)
}
}

Expand Down Expand Up @@ -602,9 +607,15 @@ pub mod test_util {

#[cfg(test)]
mod tests {
use super::test_util::new_dummy_task;
use super::{
test_util::{generate_auth_token, new_dummy_task},
Task, PRIO3_AES128_VERIFY_KEY_LENGTH,
};
use crate::{config::test_util::roundtrip_encoding, task::VdafInstance};
use janus_core::message::{Duration, Interval, Role, TaskId, Time};
use janus_core::{
hpke::test_util::generate_hpke_config_and_private_key,
message::{Duration, Interval, Role, TaskId, Time},
};
use serde_test::{assert_tokens, Token};

#[test]
Expand Down Expand Up @@ -765,4 +776,35 @@ mod tests {
Role::Leader,
));
}

#[test]
fn aggregator_endpoints_end_in_slash() {
let task = Task::new(
TaskId::random(),
Vec::from([
"http://leader_endpoint/foo/bar".parse().unwrap(),
"http://helper_endpoint".parse().unwrap(),
]),
VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count),
Role::Leader,
Vec::from([[0; PRIO3_AES128_VERIFY_KEY_LENGTH].into()]),
0,
0,
Duration::from_hours(8).unwrap(),
Duration::from_minutes(10).unwrap(),
generate_hpke_config_and_private_key().0,
Vec::from([generate_auth_token()]),
Vec::from([generate_auth_token()]),
Vec::from([generate_hpke_config_and_private_key()]),
)
.unwrap();

assert_eq!(
task.aggregator_endpoints,
Vec::from([
"http://leader_endpoint/foo/bar/".parse().unwrap(),
"http://helper_endpoint/".parse().unwrap()
])
);
}
}
2 changes: 2 additions & 0 deletions monolithic_integration_test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ janus = [
[dependencies]
backoff = "0.4"
daphne = { git = "https://github.com/cloudflare/daphne", rev = "6301e712df216a0301c42cb3177110dd8217fa84", optional = true }
futures = "0.3"
hex = "0.4"
hpke-dispatch = "0.3"
http = "0.2"
Expand All @@ -38,6 +39,7 @@ subprocess = { version = "0.2", optional = true }
tempfile = { version = "3", optional = true }
tokio = { version = "1", features = ["full", "tracing"] }
toml = "0.5"
warp = { version = "0.3", features = ["tls"] }

[dev-dependencies]
chrono = "0.4.21"
Expand Down
33 changes: 22 additions & 11 deletions monolithic_integration_test/src/janus.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! Functionality for tests interacting with Janus (<https://github.com/divviup/janus>).
use http::HeaderMap;
use futures::FutureExt;
use janus_core::{message::Duration, time::RealClock, TokioRuntime};
use janus_server::{
aggregator::{
aggregate_share::CollectJobDriver, aggregation_job_creator::AggregationJobCreator,
aggregation_job_driver::AggregationJobDriver, aggregator_server,
aggregation_job_driver::AggregationJobDriver, aggregator_filter,
},
binary_utils::job_driver::JobDriver,
datastore::test_util::{ephemeral_datastore, DbHandle},
Expand All @@ -18,6 +18,7 @@ use std::{
time,
};
use tokio::{select, sync::oneshot, task, try_join};
use warp::Filter;

/// Represents a running Janus test instance.
pub struct Janus {
Expand Down Expand Up @@ -48,15 +49,25 @@ impl Janus {

// Start aggregator server.
let (server_shutdown_sender, server_shutdown_receiver) = oneshot::channel();
let (_, leader_server) = aggregator_server(
Arc::clone(&datastore),
RealClock::default(),
SocketAddr::from((Ipv4Addr::LOCALHOST, port)),
HeaderMap::new(),
async move { server_shutdown_receiver.await.unwrap() },
)
.unwrap();
let server_task_handle = task::spawn(leader_server);
let aggregator_filter = task
.aggregator_url(task.role)
.unwrap()
.path_segments()
.unwrap()
.filter_map(|s| (!s.is_empty()).then(|| warp::path(s.to_owned()).boxed()))
.reduce(|x, y| x.and(y).boxed())
.unwrap_or_else(|| warp::any().boxed())
.and(aggregator_filter(datastore, RealClock::default()).unwrap());
let server = warp::serve(aggregator_filter);
let server_task_handle = task::spawn(async move {
server
.bind_with_graceful_shutdown(
SocketAddr::from((Ipv4Addr::LOCALHOST, port)),
server_shutdown_receiver.map(Result::unwrap),
)
.1
.await
});

// Start aggregation job creator.
let (
Expand Down
2 changes: 1 addition & 1 deletion monolithic_integration_test/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ pub async fn submit_measurements_and_verify_aggregate(
let collect_url = leader_task
.aggregator_url(Role::Leader)
.unwrap()
.join("/collect")
.join("collect")
.unwrap();
let batch_interval = Interval::new(
before_timestamp
Expand Down
9 changes: 8 additions & 1 deletion monolithic_integration_test/tests/janus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@ async fn janus_janus() {
// Start servers.
let (janus_leader_port, janus_helper_port) = pick_two_unused_ports();
let (collector_hpke_config, collector_private_key) = generate_hpke_config_and_private_key();
let (janus_leader_task, janus_helper_task) =
let (mut janus_leader_task, mut janus_helper_task) =
create_test_tasks(janus_leader_port, janus_helper_port, &collector_hpke_config);

// Update tasks to serve out of /dap/ prefix.
for task in [&mut janus_leader_task, &mut janus_helper_task] {
for url in &mut task.aggregator_endpoints {
url.set_path("/dap/");
}
}

let _janus_leader = Janus::new(janus_leader_port, &janus_leader_task).await;
let _janus_helper = Janus::new(janus_helper_port, &janus_helper_task).await;

Expand Down

0 comments on commit da42381

Please sign in to comment.