Skip to content

Commit

Permalink
polish AsyncWaitGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
al8n committed Feb 10, 2024
1 parent ffaca74 commit 01a4214
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 53 deletions.
13 changes: 9 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ homepage = "https://github.com/al8n/wg"
repository = "https://github.com/al8n/wg.git"
documentation = "https://docs.rs/wg/"
readme = "README.md"
version = "0.5.0"
version = "0.6.0"
license = "MIT OR Apache-2.0"
keywords = ["waitgroup", "async", "sync", "notify", "wake"]
categories = ["asynchronous", "concurrency", "data-structures"]
Expand All @@ -18,14 +18,19 @@ full = ["triomphe", "parking_lot"]
triomphe = ["dep:triomphe"]
parking_lot = ["dep:parking_lot"]

future = ["event-listener", "event-listener-strategy", "pin-project-lite"]

[dependencies]
parking_lot = {version = "0.12", optional = true }
parking_lot = { version = "0.12", optional = true }
triomphe = { version = "0.1", optional = true }
event-listener = { version = "5", optional = true }
event-listener-strategy = { version = "0.5", optional = true }
pin-project-lite = { version = "0.2", optional = true }

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
async-std = { version = "1.12", features = ["attributes"] }
async-std = { version = "1.12", features = ["attributes"] }

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
rustdoc-args = ["--cfg", "docsrs"]
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,25 @@ Golang like WaitGroup implementation for sync/async Rust.
</div>

## Installation

By default, blocking version `WaitGroup` is enabled, if you want to use non-blocking `AsyncWaitGroup`, you need to
enbale `future` feature in your `Cargo.toml`.

### Sync

```toml
[dependencies]
wg = "0.5"
wg = "0.6"
```

### Async

```toml
[dependencies]
wg = { version: "0.6", features = ["future"] }
```


## Example

### Sync
Expand Down
122 changes: 74 additions & 48 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,21 +371,25 @@ impl WaitGroup {
}
}

#[cfg(feature = "future")]
pub use r#async::*;

#[cfg(feature = "future")]
mod r#async {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use event_listener::{Event, EventListener};
use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};

use std::{
future::Future,
pin::Pin,
task::{Context, Poll, Waker},
sync::atomic::{AtomicUsize, Ordering},
task::Poll,
};

#[derive(Debug)]
struct AsyncInner {
waker: Mutex<Option<Waker>>,
count: AtomicUsize,
counter: AtomicUsize,
event: Event,
}

/// An AsyncWaitGroup waits for a collection of threads to finish.
Expand Down Expand Up @@ -429,7 +433,7 @@ mod r#async {
///
/// [`wait`]: struct.AsyncWaitGroup.html#method.wait
/// [`add`]: struct.AsyncWaitGroup.html#method.add
#[cfg_attr(docsrs, doc(cfg(feature = "test")))]
#[cfg_attr(docsrs, doc(cfg(feature = "future")))]
pub struct AsyncWaitGroup {
inner: Arc<AsyncInner>,
}
Expand All @@ -438,8 +442,8 @@ mod r#async {
fn default() -> Self {
Self {
inner: Arc::new(AsyncInner {
count: AtomicUsize::new(0),
waker: Mutex::new(None),
counter: AtomicUsize::new(0),
event: Event::new(),
}),
}
}
Expand All @@ -449,8 +453,8 @@ mod r#async {
fn from(count: usize) -> Self {
Self {
inner: Arc::new(AsyncInner {
count: AtomicUsize::new(count),
waker: Mutex::new(None),
counter: AtomicUsize::new(count),
event: Event::new(),
}),
}
}
Expand All @@ -466,10 +470,8 @@ mod r#async {

impl std::fmt::Debug for AsyncWaitGroup {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let count = self.inner.count.load(Ordering::Relaxed);

f.debug_struct("AsyncWaitGroup")
.field("count", &count)
.field("counter", &self.inner.counter)
.finish()
}
}
Expand Down Expand Up @@ -513,7 +515,7 @@ mod r#async {
///
/// [`wait`]: struct.AsyncWaitGroup.html#method.wait
pub fn add(&self, num: usize) -> Self {
self.inner.count.fetch_add(num, Ordering::SeqCst);
self.inner.counter.fetch_add(num, Ordering::AcqRel);

Self {
inner: self.inner.clone(),
Expand All @@ -539,18 +541,14 @@ mod r#async {
/// }
/// ```
pub fn done(&self) {
let count = self.inner.count.fetch_sub(1, Ordering::Relaxed);
// We are the last worker
if count == 1 {
if let Some(waker) = self.inner.waker.lock_me().take() {
waker.wake();
}
if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
self.inner.event.notify(usize::MAX);
}
}

/// waitings return how many jobs are waiting.
pub fn waitings(&self) -> usize {
self.inner.count.load(Ordering::Acquire)
self.inner.counter.load(Ordering::Acquire)
}

/// wait blocks until the [`AsyncWaitGroup`] counter is zero.
Expand All @@ -575,8 +573,8 @@ mod r#async {
/// wg.wait().await;
/// }
/// ```
pub async fn wait(&self) {
WaitGroupFuture::new(&self.inner).await
pub fn wait(&self) -> WaitGroupFuture<'_> {
WaitGroupFuture::_new(WaitGroupFutureInner::new(&self.inner))
}

/// Wait blocks until the [`AsyncWaitGroup`] counter is zero. This method is
Expand Down Expand Up @@ -606,39 +604,67 @@ mod r#async {
/// }
/// ```
pub fn block_wait(&self) {
loop {
match self.inner.count.load(Ordering::Acquire) {
0 => return,
_ => core::hint::spin_loop(),
}
}
WaitGroupFutureInner::new(&self.inner).wait();
}
}

struct WaitGroupFuture<'a> {
inner: &'a Arc<AsyncInner>,
easy_wrapper! {
/// A future returned by [`AsyncWaitGroup::wait()`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[cfg_attr(docsrs, doc(cfg(feature = "future")))]
pub struct WaitGroupFuture<'a>(WaitGroupFutureInner<'a> => ());

#[cfg(all(feature = "std", not(target_family = "wasm")))]
pub(crate) wait();
}

impl<'a> WaitGroupFuture<'a> {
pin_project_lite::pin_project! {
/// A future that used to wait for the [`AsyncWaitGroup`] counter is zero.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[project(!Unpin)]
#[derive(Debug)]
struct WaitGroupFutureInner<'a> {
inner: &'a Arc<AsyncInner>,
listener: Option<EventListener>,
#[pin]
_pin: std::marker::PhantomPinned,
}
}

impl<'a> WaitGroupFutureInner<'a> {
fn new(inner: &'a Arc<AsyncInner>) -> Self {
Self { inner }
Self {
inner,
listener: None,
_pin: std::marker::PhantomPinned,
}
}
}

impl Future for WaitGroupFuture<'_> {
impl EventListenerFuture for WaitGroupFutureInner<'_> {
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.inner.count.load(Ordering::Relaxed) == 0 {
return Poll::Ready(());
}

let waker = cx.waker().clone();
*self.inner.waker.lock_me() = Some(waker);
fn poll_with_strategy<'a, S: Strategy<'a>>(
self: Pin<&mut Self>,
strategy: &mut S,
context: &mut S::Context,
) -> Poll<Self::Output> {
let this = self.project();
loop {
if this.inner.counter.load(Ordering::Acquire) == 0 {
return Poll::Ready(());
}

match self.inner.count.load(Ordering::Relaxed) {
0 => Poll::Ready(()),
_ => Poll::Pending,
if this.listener.is_some() {
// Poll using the given strategy
match S::poll(strategy, &mut *this.listener, context) {
Poll::Ready(_) => {}
Poll::Pending => return Poll::Pending,
}
} else {
*this.listener = Some(this.inner.event.listen());
}
}
}
}
Expand Down Expand Up @@ -759,13 +785,13 @@ mod r#async {
assert_eq!(wg.waitings(), 2);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_async_block_wait() {
#[test]
fn test_async_block_wait() {
let wg = AsyncWaitGroup::new();
let t_wg = wg.add(1);
tokio::spawn(async move {
std::thread::spawn(move || {
// do some time consuming task
t_wg.done()
t_wg.done();
});

// wait other thread completes
Expand Down

0 comments on commit 01a4214

Please sign in to comment.