Skip to content

Commit

Permalink
fix(shutdown): is_triggered returns up-to-date value without first po…
Browse files Browse the repository at this point in the history
…lling
  • Loading branch information
sdbondi committed Nov 30, 2023
1 parent 6723dc7 commit 3e9dfaf
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 21 deletions.
53 changes: 33 additions & 20 deletions infrastructure/shutdown/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@ pub mod oneshot_trigger;
use std::{
future::Future,
pin::Pin,
sync::{atomic, atomic::AtomicBool, Arc},
task::{Context, Poll},
};

use futures::{future, future::FusedFuture};

use crate::oneshot_trigger::OneshotSignal;

/// Trigger for shutdowns.
///
/// Use `to_signal` to create a future which will resolve when `Shutdown` is triggered.
Expand All @@ -40,22 +39,32 @@ use crate::oneshot_trigger::OneshotSignal;
/// _Note_: This will trigger when dropped, so the `Shutdown` instance should be held as
/// long as required by the application.
#[derive(Clone, Debug)]
pub struct Shutdown(oneshot_trigger::OneshotTrigger<()>);
pub struct Shutdown {
trigger: oneshot_trigger::OneshotTrigger<()>,
is_triggered: Arc<AtomicBool>,
}
impl Shutdown {
pub fn new() -> Self {
Self(oneshot_trigger::OneshotTrigger::new())
Self {
trigger: oneshot_trigger::OneshotTrigger::new(),
is_triggered: Arc::new(AtomicBool::new(false)),
}
}

pub fn trigger(&mut self) {
self.0.broadcast(());
self.trigger.broadcast(());
self.is_triggered.store(true, atomic::Ordering::SeqCst);
}

pub fn is_triggered(&self) -> bool {
self.0.is_used()
self.trigger.is_used()
}

pub fn to_signal(&self) -> ShutdownSignal {
self.0.to_signal().into()
ShutdownSignal {
inner: self.trigger.to_signal().into(),
is_triggered: self.is_triggered.clone(),
}
}
}

Expand All @@ -67,11 +76,17 @@ impl Default for Shutdown {

/// Receiver end of a shutdown signal. Once received the consumer should shut down.
#[derive(Debug, Clone)]
pub struct ShutdownSignal(oneshot_trigger::OneshotSignal<()>);
pub struct ShutdownSignal {
inner: oneshot_trigger::OneshotSignal<()>,
is_triggered: Arc<AtomicBool>,
}

impl ShutdownSignal {
pub fn is_triggered(&self) -> bool {
self.0.is_terminated()
// Shared future in OneshotTrigger requires a poll before is_terminated returns true.
// For our use case here, we expect is_triggered to return true _immediately_ as the trigger is fired without
// first polling the signal. To this end, we use an AtomicBool to track the triggered state.
self.is_triggered.load(atomic::Ordering::SeqCst)
}

/// Wait for the shutdown signal to trigger.
Expand All @@ -88,7 +103,7 @@ impl Future for ShutdownSignal {
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.0).poll(cx) {
match Pin::new(&mut self.inner).poll(cx) {
// Whether `trigger()` was called Some(()), or the Shutdown dropped (None) we want to resolve this future
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
Expand All @@ -98,13 +113,7 @@ impl Future for ShutdownSignal {

impl FusedFuture for ShutdownSignal {
fn is_terminated(&self) -> bool {
self.0.is_terminated()
}
}

impl From<oneshot_trigger::OneshotSignal<()>> for ShutdownSignal {
fn from(inner: OneshotSignal<()>) -> Self {
Self(inner)
self.is_triggered()
}
}

Expand Down Expand Up @@ -167,6 +176,7 @@ impl FusedFuture for OptionalShutdownSignal {

#[cfg(test)]
mod test {

use tokio::task;

use super::*;
Expand All @@ -191,12 +201,15 @@ mod test {
async fn signal_clone() {
let mut shutdown = Shutdown::new();
let signal = shutdown.to_signal();
let signal_clone = signal.clone();
let mut signal_clone = signal.clone();
let fut = task::spawn(async move {
signal_clone.await;
signal.await;
signal_clone.wait().await;
assert!(signal_clone.is_triggered());
});
assert!(!signal.is_triggered());
shutdown.trigger();
assert!(signal.is_triggered());
assert!(shutdown.is_triggered());
fut.await.unwrap();
}

Expand Down
2 changes: 1 addition & 1 deletion infrastructure/shutdown/src/oneshot_trigger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<T: Clone> Future for OneshotSignal<T> {
type Output = Option<T>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.inner.is_terminated() {
if self.is_terminated() {
return Poll::Ready(None);
}

Expand Down

0 comments on commit 3e9dfaf

Please sign in to comment.