From 788ad1301b8fa393a2c21fb6a4e8a298eaf0e034 Mon Sep 17 00:00:00 2001 From: Stan Bondi Date: Thu, 30 Nov 2023 11:21:46 +0400 Subject: [PATCH] fix(shutdown): is_triggered returns up-to-date value without first polling --- infrastructure/shutdown/src/lib.rs | 53 ++++++++++++------- .../shutdown/src/oneshot_trigger.rs | 2 +- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/infrastructure/shutdown/src/lib.rs b/infrastructure/shutdown/src/lib.rs index 18226edcf1..01aa3d3b21 100644 --- a/infrastructure/shutdown/src/lib.rs +++ b/infrastructure/shutdown/src/lib.rs @@ -25,13 +25,12 @@ pub mod oneshot_trigger; use std::{ future::Future, pin::Pin, + sync::{atomic, atomic::AtomicBool, Arc}, task::{Context, Poll}, }; use futures::{future, future::FusedFuture}; -use crate::oneshot_trigger::OneshotSignal; - /// Trigger for shutdowns. /// /// Use `to_signal` to create a future which will resolve when `Shutdown` is triggered. @@ -40,22 +39,32 @@ use crate::oneshot_trigger::OneshotSignal; /// _Note_: This will trigger when dropped, so the `Shutdown` instance should be held as /// long as required by the application. #[derive(Clone, Debug)] -pub struct Shutdown(oneshot_trigger::OneshotTrigger<()>); +pub struct Shutdown { + trigger: oneshot_trigger::OneshotTrigger<()>, + is_triggered: Arc, +} impl Shutdown { pub fn new() -> Self { - Self(oneshot_trigger::OneshotTrigger::new()) + Self { + trigger: oneshot_trigger::OneshotTrigger::new(), + is_triggered: Arc::new(AtomicBool::new(false)), + } } pub fn trigger(&mut self) { - self.0.broadcast(()); + self.trigger.broadcast(()); + self.is_triggered.store(true, atomic::Ordering::SeqCst); } pub fn is_triggered(&self) -> bool { - self.0.is_used() + self.trigger.is_used() } pub fn to_signal(&self) -> ShutdownSignal { - self.0.to_signal().into() + ShutdownSignal { + inner: self.trigger.to_signal(), + is_triggered: self.is_triggered.clone(), + } } } @@ -67,11 +76,17 @@ impl Default for Shutdown { /// Receiver end of a shutdown signal. Once received the consumer should shut down. #[derive(Debug, Clone)] -pub struct ShutdownSignal(oneshot_trigger::OneshotSignal<()>); +pub struct ShutdownSignal { + inner: oneshot_trigger::OneshotSignal<()>, + is_triggered: Arc, +} impl ShutdownSignal { pub fn is_triggered(&self) -> bool { - self.0.is_terminated() + // Shared future in OneshotTrigger requires a poll before is_terminated returns true. + // For our use case here, we expect is_triggered to return true _immediately_ as the trigger is fired without + // first polling the signal. To this end, we use an AtomicBool to track the triggered state. + self.is_triggered.load(atomic::Ordering::SeqCst) } /// Wait for the shutdown signal to trigger. @@ -88,7 +103,7 @@ impl Future for ShutdownSignal { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match Pin::new(&mut self.0).poll(cx) { + match Pin::new(&mut self.inner).poll(cx) { // Whether `trigger()` was called Some(()), or the Shutdown dropped (None) we want to resolve this future Poll::Ready(_) => Poll::Ready(()), Poll::Pending => Poll::Pending, @@ -98,13 +113,7 @@ impl Future for ShutdownSignal { impl FusedFuture for ShutdownSignal { fn is_terminated(&self) -> bool { - self.0.is_terminated() - } -} - -impl From> for ShutdownSignal { - fn from(inner: OneshotSignal<()>) -> Self { - Self(inner) + self.is_triggered() } } @@ -167,6 +176,7 @@ impl FusedFuture for OptionalShutdownSignal { #[cfg(test)] mod test { + use tokio::task; use super::*; @@ -191,12 +201,15 @@ mod test { async fn signal_clone() { let mut shutdown = Shutdown::new(); let signal = shutdown.to_signal(); - let signal_clone = signal.clone(); + let mut signal_clone = signal.clone(); let fut = task::spawn(async move { - signal_clone.await; - signal.await; + signal_clone.wait().await; + assert!(signal_clone.is_triggered()); }); + assert!(!signal.is_triggered()); shutdown.trigger(); + assert!(signal.is_triggered()); + assert!(shutdown.is_triggered()); fut.await.unwrap(); } diff --git a/infrastructure/shutdown/src/oneshot_trigger.rs b/infrastructure/shutdown/src/oneshot_trigger.rs index 4b47943e9c..47afafbe03 100644 --- a/infrastructure/shutdown/src/oneshot_trigger.rs +++ b/infrastructure/shutdown/src/oneshot_trigger.rs @@ -90,7 +90,7 @@ impl Future for OneshotSignal { type Output = Option; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.inner.is_terminated() { + if self.is_terminated() { return Poll::Ready(None); }