Skip to content

Commit

Permalink
h2-support,h2-tests: add tools to ensure wake
Browse files Browse the repository at this point in the history
This commit adds wrappers around futures::future helpers and augments
TestFuture to ensure that the underlying futures are notified before
they are polled. This helps to catch bugs where there are missing notify
calls or bad handling of the waker.

The commit then extends the tests to use these helpers instead of the
library functions from futures.

It also ammends the client_requests::recv_too_big_headers test to no
longer use the tokio spawned tasks that were added in hyperium#791.
  • Loading branch information
ajwerner committed Aug 6, 2024
1 parent 0f10650 commit a4468f0
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 38 deletions.
141 changes: 137 additions & 4 deletions tests/h2-support/src/future_ext.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use futures::FutureExt;
use futures::{FutureExt, TryFuture};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Future extension helpers that are useful for tests
pub trait TestFuture: Future {
Expand All @@ -15,9 +17,140 @@ pub trait TestFuture: Future {
{
Drive {
driver: self,
future: Box::pin(other),
future: other.wakened(),
}
}

fn wakened(self) -> Wakened<Self>
where
Self: Sized,
{
Wakened {
future: Box::pin(self),
woken: Arc::new(AtomicBool::new(true)),
}
}
}

/// Wraps futures::future::join to ensure that the futures are only polled if they are woken.
pub fn join<Fut1, Fut2>(
future1: Fut1,
future2: Fut2,
) -> futures::future::Join<Wakened<Fut1>, Wakened<Fut2>>
where
Fut1: Future,
Fut2: Future,
{
futures::future::join(future1.wakened(), future2.wakened())
}

/// Wraps futures::future::join3 to ensure that the futures are only polled if they are woken.
pub fn join3<Fut1, Fut2, Fut3>(
future1: Fut1,
future2: Fut2,
future3: Fut3,
) -> futures::future::Join3<Wakened<Fut1>, Wakened<Fut2>, Wakened<Fut3>>
where
Fut1: Future,
Fut2: Future,
Fut3: Future,
{
futures::future::join3(future1.wakened(), future2.wakened(), future3.wakened())
}

/// Wraps futures::future::join4 to ensure that the futures are only polled if they are woken.
pub fn join4<Fut1, Fut2, Fut3, Fut4>(
future1: Fut1,
future2: Fut2,
future3: Fut3,
future4: Fut4,
) -> futures::future::Join4<Wakened<Fut1>, Wakened<Fut2>, Wakened<Fut3>, Wakened<Fut4>>
where
Fut1: Future,
Fut2: Future,
Fut3: Future,
Fut4: Future,
{
futures::future::join4(
future1.wakened(),
future2.wakened(),
future3.wakened(),
future4.wakened(),
)
}

/// Wraps futures::future::try_join to ensure that the futures are only polled if they are woken.
pub fn try_join<Fut1, Fut2>(
future1: Fut1,
future2: Fut2,
) -> futures::future::TryJoin<Wakened<Fut1>, Wakened<Fut2>>
where
Fut1: futures::future::TryFuture + Future,
Fut2: Future,
Wakened<Fut1>: futures::future::TryFuture,
Wakened<Fut2>: futures::future::TryFuture<Error = <Wakened<Fut1> as TryFuture>::Error>,
{
futures::future::try_join(future1.wakened(), future2.wakened())
}

/// Wraps futures::future::select to ensure that the futures are only polled if they are woken.
pub fn select<A, B>(future1: A, future2: B) -> futures::future::Select<Wakened<A>, Wakened<B>>
where
A: Future + Unpin,
B: Future + Unpin,
{
futures::future::select(future1.wakened(), future2.wakened())
}

/// Wraps futures::future::join_all to ensure that the futures are only polled if they are woken.
pub fn join_all<I>(iter: I) -> futures::future::JoinAll<Wakened<I::Item>>
where
I: IntoIterator,
I::Item: Future,
{
futures::future::join_all(iter.into_iter().map(|f| f.wakened()))
}

/// A future that only polls the inner future if it has been woken (after the initial poll).
pub struct Wakened<T> {
future: Pin<Box<T>>,
woken: Arc<AtomicBool>,
}

/// A future that only polls the inner future if it has been woken (after the initial poll).
impl<T> Future for Wakened<T>
where
T: Future,
{
type Output = T::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if !this.woken.load(std::sync::atomic::Ordering::SeqCst) {
return Poll::Pending;
}
this.woken.store(false, std::sync::atomic::Ordering::SeqCst);
let my_waker = IfWokenWaker {
inner: cx.waker().clone(),
wakened: this.woken.clone(),
};
let my_waker = Arc::new(my_waker).into();
let mut cx = Context::from_waker(&my_waker);
this.future.as_mut().poll(&mut cx)
}
}

impl Wake for IfWokenWaker {
fn wake(self: Arc<Self>) {
self.wakened
.store(true, std::sync::atomic::Ordering::SeqCst);
self.inner.wake_by_ref();
}
}

struct IfWokenWaker {
inner: Waker,
wakened: Arc<AtomicBool>,
}

impl<T: Future> TestFuture for T {}
Expand All @@ -29,7 +162,7 @@ impl<T: Future> TestFuture for T {}
/// This is useful for H2 futures that also require the connection to be polled.
pub struct Drive<'a, T, U> {
driver: &'a mut T,
future: Pin<Box<U>>,
future: Wakened<U>,
}

impl<'a, T, U> Future for Drive<'a, T, U>
Expand Down
2 changes: 1 addition & 1 deletion tests/h2-support/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub use {bytes, futures, http, tokio::io as tokio_io, tracing, tracing_subscribe
pub use futures::{Future, Sink, Stream};

// And our Future extensions
pub use super::future_ext::TestFuture;
pub use super::future_ext::{join, join3, join4, join_all, select, try_join, TestFuture};

// Our client_ext helpers
pub use super::client_ext::SendRequestExt;
Expand Down
21 changes: 7 additions & 14 deletions tests/h2-tests/tests/client_request.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures::future::{join, join_all, ready, select, Either};
use futures::future::{ready, Either};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use h2_support::prelude::*;
Expand Down Expand Up @@ -849,7 +849,7 @@ async fn recv_too_big_headers() {
};

let client = async move {
let (mut client, conn) = client::Builder::new()
let (mut client, mut conn) = client::Builder::new()
.max_header_list_size(10)
.handshake::<_, Bytes>(io)
.await
Expand All @@ -863,30 +863,23 @@ async fn recv_too_big_headers() {
let req1 = client.send_request(request, true);
// Spawn tasks to ensure that the error wakes up tasks that are blocked
// waiting for a response.
let req1 = tokio::spawn(async move {
let req1 = async move {
let err = req1.expect("send_request").0.await.expect_err("response1");
assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM));
});
};

let request = Request::builder()
.uri("https://http2.akamai.com/")
.body(())
.unwrap();

let req2 = client.send_request(request, true);
let req2 = tokio::spawn(async move {
let req2 = async move {
let err = req2.expect("send_request").0.await.expect_err("response2");
assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM));
});
};

let conn = tokio::spawn(async move {
conn.await.expect("client");
});
for err in join_all([req1, req2, conn]).await {
if let Some(err) = err.err().and_then(|err| err.try_into_panic().ok()) {
std::panic::resume_unwind(err);
}
}
conn.drive(join(req1, req2)).await;
};

join(srv, client).await;
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/codec_read.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::future::join;
use h2_support::prelude::*;

#[tokio::test]
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/codec_write.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::future::join;
use h2_support::prelude::*;

#[tokio::test]
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/flow_control.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::future::{join, join4};
use futures::{StreamExt, TryStreamExt};
use h2_support::prelude::*;
use h2_support::util::yield_once;
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/ping_pong.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use futures::channel::oneshot;
use futures::future::join;
use futures::StreamExt;
use h2_support::assert_ping;
use h2_support::prelude::*;
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/prioritization.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::future::{join, select};
use futures::{pin_mut, FutureExt, StreamExt};

use h2_support::prelude::*;
Expand Down
16 changes: 4 additions & 12 deletions tests/h2-tests/tests/push_promise.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::iter::FromIterator;

use futures::{future::join, FutureExt as _, StreamExt, TryStreamExt};
use futures::{StreamExt, TryStreamExt};
use h2_support::prelude::*;

#[tokio::test]
Expand Down Expand Up @@ -52,15 +50,9 @@ async fn recv_push_works() {
let ps: Vec<_> = p.collect().await;
assert_eq!(1, ps.len())
};
// Use a FuturesUnordered to poll both tasks but only poll them
// if they have been notified.
let tasks = futures::stream::FuturesUnordered::from_iter([
check_resp_status.boxed(),
check_pushed_response.boxed(),
])
.collect::<()>();

h2.drive(tasks).await;

h2.drive(join(check_resp_status, check_pushed_response))
.await;
};

join(mock, h2).await;
Expand Down
1 change: 0 additions & 1 deletion tests/h2-tests/tests/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![deny(warnings)]

use futures::future::join;
use futures::StreamExt;
use h2_support::prelude::*;
use tokio::io::AsyncWriteExt;
Expand Down
2 changes: 1 addition & 1 deletion tests/h2-tests/tests/stream_states.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![deny(warnings)]

use futures::future::{join, join3, lazy, try_join};
use futures::future::lazy;
use futures::{FutureExt, StreamExt, TryStreamExt};
use h2_support::prelude::*;
use h2_support::util::yield_once;
Expand Down

0 comments on commit a4468f0

Please sign in to comment.