diff --git a/Cargo.lock b/Cargo.lock index 5ac46bcb4..add539a3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5006,7 +5006,7 @@ dependencies = [ [[package]] name = "rover" -version = "0.26.3" +version = "0.27.0-preview.1" dependencies = [ "anyhow", "apollo-federation-types", diff --git a/Cargo.toml b/Cargo.toml index 365e401aa..5542d2369 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ license-file = "./LICENSE" name = "rover" readme = "README.md" repository = "https://github.com/apollographql/rover/" -version = "0.26.3" +version = "0.27.0-preview.1" default-run = "rover" publish = false diff --git a/crates/rover-std/src/fs.rs b/crates/rover-std/src/fs.rs index a4e9df971..c47da59a0 100644 --- a/crates/rover-std/src/fs.rs +++ b/crates/rover-std/src/fs.rs @@ -261,8 +261,9 @@ impl Fs { pub fn watch_file( path: PathBuf, tx: UnboundedSender>, + cancellation_token: Option, ) -> CancellationToken { - let cancellation_token = CancellationToken::new(); + let cancellation_token = cancellation_token.unwrap_or_default(); let poll_watcher = PollWatcher::new( { @@ -452,7 +453,7 @@ mod tests { let path = file.path().to_path_buf(); let (tx, rx) = unbounded_channel(); let rx = Arc::new(Mutex::new(rx)); - let cancellation_token = Fs::watch_file(path.clone(), tx); + let cancellation_token = Fs::watch_file(path.clone(), tx, None); sleep(Duration::from_millis(1500)).await; @@ -513,7 +514,7 @@ mod tests { let (tx, rx) = unbounded_channel(); let rx = Arc::new(Mutex::new(rx)); - let _cancellation_token = Fs::watch_file(path.clone(), tx); + let _cancellation_token = Fs::watch_file(path.clone(), tx, None); sleep(Duration::from_millis(1500)).await; diff --git a/src/command/dev/legacy/router/config.rs b/src/command/dev/legacy/router/config.rs index 7f0305a09..cd6b9a386 100644 --- a/src/command/dev/legacy/router/config.rs +++ b/src/command/dev/legacy/router/config.rs @@ -7,9 +7,8 @@ use std::{ use anyhow::{anyhow, Context}; use camino::Utf8PathBuf; use crossbeam_channel::{unbounded, Receiver}; -use serde_json::json; - use rover_std::{warnln, Fs}; +use serde_json::json; use crate::utils::expansion::expand; use crate::{ @@ -269,7 +268,7 @@ impl RouterConfigReader { let (raw_tx, mut raw_rx) = tokio::sync::mpsc::unbounded_channel(); let (state_tx, state_rx) = unbounded(); let input_config_path: PathBuf = input_config_path.as_path().into(); - Fs::watch_file(input_config_path, raw_tx); + Fs::watch_file(input_config_path, raw_tx, None); tokio::spawn(async move { while let Some(res) = raw_rx.recv().await { res.expect("could not watch router configuration file"); diff --git a/src/command/dev/legacy/watcher.rs b/src/command/dev/legacy/watcher.rs index a9a800ef9..01df79988 100644 --- a/src/command/dev/legacy/watcher.rs +++ b/src/command/dev/legacy/watcher.rs @@ -6,14 +6,13 @@ use anyhow::{anyhow, Context}; use apollo_federation_types::javascript::SubgraphDefinition; use camino::{Utf8Path, Utf8PathBuf}; use reqwest::Client; -use tokio::time::MissedTickBehavior::Delay; -use url::Url; - use rover_client::blocking::StudioClient; use rover_client::operations::subgraph::fetch; use rover_client::operations::subgraph::fetch::SubgraphFetchInput; use rover_client::shared::GraphRef; use rover_std::{errln, Fs}; +use tokio::time::MissedTickBehavior::Delay; +use url::Url; use crate::{ command::dev::legacy::{ @@ -305,7 +304,7 @@ impl SubgraphSchemaWatcher { let watch_path: PathBuf = path.as_path().into(); - Fs::watch_file(watch_path, tx); + Fs::watch_file(watch_path, tx, None); while let Some(res) = rx.recv().await { match res { diff --git a/src/command/dev/next/router/watchers/file.rs b/src/command/dev/next/router/watchers/file.rs index 297dfbbd5..3f97422a4 100644 --- a/src/command/dev/next/router/watchers/file.rs +++ b/src/command/dev/next/router/watchers/file.rs @@ -30,7 +30,7 @@ impl FileWatcher { let path = self.path; let (file_tx, file_rx) = unbounded_channel(); let output = UnboundedReceiverStream::new(file_rx); - let cancellation_token = Fs::watch_file(path.as_path().into(), file_tx); + let cancellation_token = Fs::watch_file(path.as_path().into(), file_tx, None); output .filter_map(move |result| { @@ -61,13 +61,14 @@ impl FileWatcher { #[cfg(test)] mod tests { - use futures::StreamExt; - use speculoos::assert_that; - use speculoos::option::OptionAssertions; use std::fs::OpenOptions; use std::io::Write; use std::time::Duration; + use futures::StreamExt; + use speculoos::assert_that; + use speculoos::option::OptionAssertions; + use super::*; #[tokio::test] diff --git a/src/composition/runner/mod.rs b/src/composition/runner/mod.rs index ba756b58b..a92ad31f1 100644 --- a/src/composition/runner/mod.rs +++ b/src/composition/runner/mod.rs @@ -11,6 +11,7 @@ use std::{ use camino::Utf8PathBuf; use futures::stream::{select, BoxStream, StreamExt}; use rover_http::HttpService; +use tokio_util::sync::CancellationToken; use tower::ServiceExt; use self::state::SetupSubgraphWatchers; @@ -220,7 +221,10 @@ where // events in order to trigger recomposition. let (composition_messages, composition_subtask) = Subtask::new(self.state.composition_watcher); - composition_subtask.run(select(subgraph_change_stream, federation_watcher_stream).boxed()); + composition_subtask.run( + select(subgraph_change_stream, federation_watcher_stream).boxed(), + None, + ); // Start subgraph watchers, listening for events from the supergraph change stream. subgraph_watcher_subtask.run( @@ -235,6 +239,7 @@ where } }) .boxed(), + Some(CancellationToken::new()), ); federation_watcher_subtask.run( @@ -249,11 +254,12 @@ where } }) .boxed(), + None, ); // Start the supergraph watcher subtask. if let Some(supergraph_config_subtask) = supergraph_config_subtask { - supergraph_config_subtask.run(); + supergraph_config_subtask.run(None); } composition_messages.boxed() diff --git a/src/composition/watchers/composition.rs b/src/composition/watchers/composition.rs index 50eac955e..2042c0c31 100644 --- a/src/composition/watchers/composition.rs +++ b/src/composition/watchers/composition.rs @@ -4,8 +4,9 @@ use camino::Utf8PathBuf; use futures::stream::BoxStream; use rover_std::{errln, infoln, warnln}; use tap::TapFallible; -use tokio::{sync::mpsc::UnboundedSender, task::AbortHandle}; +use tokio::sync::mpsc::UnboundedSender; use tokio_stream::StreamExt; +use tokio_util::sync::CancellationToken; use tracing::error; use crate::composition::supergraph::install::InstallSupergraph; @@ -68,38 +69,39 @@ where mut self, sender: UnboundedSender, mut input: BoxStream<'static, Self::Input>, - ) -> AbortHandle { - tokio::task::spawn({ + cancellation_token: Option, + ) { + tokio::task::spawn(async move { let mut supergraph_config = self.supergraph_config.clone(); let target_file = self.temp_dir.join("supergraph.yaml"); - async move { - if self.compose_on_initialisation { - if let Err(err) = self - .setup_temporary_supergraph_yaml(&supergraph_config, &target_file) - .await - { - error!("Could not setup initial supergraph schema: {}", err); - }; - let _ = sender - .send(CompositionEvent::Started) - .tap_err(|err| error!("{:?}", err)); - let output = self - .run_composition(&target_file, &self.output_target) - .await; - match output { - Ok(success) => { - let _ = sender - .send(CompositionEvent::Success(success)) - .tap_err(|err| error!("{:?}", err)); - } - Err(err) => { - let _ = sender - .send(CompositionEvent::Error(err)) - .tap_err(|err| error!("{:?}", err)); - } + if self.compose_on_initialisation { + if let Err(err) = self + .setup_temporary_supergraph_yaml(&supergraph_config, &target_file) + .await + { + error!("Could not setup initial supergraph schema: {}", err); + }; + let _ = sender + .send(CompositionEvent::Started) + .tap_err(|err| error!("{:?}", err)); + let output = self + .run_composition(&target_file, &self.output_target) + .await; + match output { + Ok(success) => { + let _ = sender + .send(CompositionEvent::Success(success)) + .tap_err(|err| error!("{:?}", err)); + } + Err(err) => { + let _ = sender + .send(CompositionEvent::Error(err)) + .tap_err(|err| error!("{:?}", err)); } } - + } + let cancellation_token = cancellation_token.unwrap_or_default(); + cancellation_token.run_until_cancelled(async { while let Some(event) = input.next().await { match event { Subgraph(SubgraphEvent::SubgraphChanged(subgraph_schema_changed)) => { @@ -193,9 +195,8 @@ where } } } - } - }) - .abort_handle() + }).await; + }); } } @@ -273,6 +274,7 @@ mod tests { use rstest::rstest; use semver::Version; use speculoos::prelude::*; + use tokio_util::sync::CancellationToken; use tracing_test::traced_test; use super::{CompositionInputEvent, CompositionWatcher}; @@ -383,7 +385,8 @@ mod tests { }) .boxed(); let (mut composition_messages, composition_subtask) = Subtask::new(composition_handler); - let abort_handle = composition_subtask.run(subgraph_change_events); + let cancellation_token = CancellationToken::new(); + composition_subtask.run(subgraph_change_events, Some(cancellation_token.clone())); // Assert we always get a subgraph added event. let next_message = composition_messages.next().await; @@ -420,7 +423,7 @@ mod tests { )); } - abort_handle.abort(); + cancellation_token.cancel().await; Ok(()) } } diff --git a/src/composition/watchers/federation.rs b/src/composition/watchers/federation.rs index e8111cae5..1dacfef93 100644 --- a/src/composition/watchers/federation.rs +++ b/src/composition/watchers/federation.rs @@ -3,7 +3,7 @@ use futures::stream::BoxStream; use futures::StreamExt; use tap::TapFallible; use tokio::sync::mpsc::UnboundedSender; -use tokio::task::AbortHandle; +use tokio_util::sync::CancellationToken; use tracing::error; use crate::composition::events::CompositionEvent; @@ -25,32 +25,42 @@ impl SubtaskHandleStream for FederationWatcher { self, sender: UnboundedSender, mut input: BoxStream<'static, Self::Input>, - ) -> AbortHandle { - tokio::task::spawn(async move { - while let Some(recv_res) = input.next().await { - match recv_res { - Ok(diff) => { - if let Some(fed_version) = diff.federation_version() { - let _ = sender - .send(CompositionInputEvent::Federation( - fed_version.clone().unwrap_or(LatestFedTwo), - )) - .tap_err(|err| error!("{:?}", err)); + cancellation_token: Option, + ) { + let cancellation_token = cancellation_token.unwrap_or_default(); + tokio::spawn(async move { + let cancellation_token = cancellation_token.clone(); + cancellation_token + .run_until_cancelled(async move { + while let Some(recv_res) = input.next().await { + match recv_res { + Ok(diff) => { + if let Some(fed_version) = diff.federation_version() { + let _ = sender + .send(CompositionInputEvent::Federation( + fed_version.clone().unwrap_or(LatestFedTwo), + )) + .tap_err(|err| error!("{:?}", err)); + } + } + Err(SupergraphConfigSerialisationError::DeserializingConfigError { + source, + }) => { + let _ = sender + .send(CompositionInputEvent::Passthrough( + CompositionEvent::Error( + CompositionError::InvalidSupergraphConfig( + source.message(), + ), + ), + )) + .tap_err(|err| error!("{:?}", err)); + } + Err(_) => {} } } - Err(SupergraphConfigSerialisationError::DeserializingConfigError { - source, - }) => { - let _ = sender - .send(CompositionInputEvent::Passthrough(CompositionEvent::Error( - CompositionError::InvalidSupergraphConfig(source.message()), - ))) - .tap_err(|err| error!("{:?}", err)); - } - Err(_) => {} - } - } - }) - .abort_handle() + }) + .await; + }); } } diff --git a/src/composition/watchers/subgraphs.rs b/src/composition/watchers/subgraphs.rs index fda835b02..f973e95d7 100644 --- a/src/composition/watchers/subgraphs.rs +++ b/src/composition/watchers/subgraphs.rs @@ -6,7 +6,8 @@ use futures::stream::{self, BoxStream, StreamExt}; use itertools::Itertools; use rover_std::errln; use tap::TapFallible; -use tokio::{sync::mpsc::UnboundedSender, task::AbortHandle}; +use tokio::{select, sync::mpsc::UnboundedSender}; +use tokio_util::sync::CancellationToken; use super::watcher::{ subgraph::{NonRepeatingFetch, SubgraphWatcher, SubgraphWatcherKind}, @@ -167,61 +168,62 @@ impl SubtaskHandleStream for SubgraphWatchers { self, sender: UnboundedSender, mut input: BoxStream<'static, Self::Input>, - ) -> AbortHandle { + cancellation_token: Option, + ) { tokio::task::spawn(async move { let mut subgraph_handles = SubgraphHandles::new( sender.clone(), self.watchers.clone(), self.resolve_introspect_subgraph_factory.clone(), self.fetch_remote_subgraph_factory.clone(), - self.supergraph_config_root.clone() + self.supergraph_config_root.clone(), ); + let cancellation_token = cancellation_token.unwrap_or_default(); + cancellation_token.run_until_cancelled(async move { + while let Some(diff) = input.next().await { + match diff { + Ok(diff) => { + // If we detect additional diffs, start a new subgraph subtask. + // Adding the abort handle to the current collection of handles. + for (subgraph_name, subgraph_config) in diff.added() { + let _ = subgraph_handles.add( + subgraph_name, + subgraph_config, + self.introspection_polling_interval + ).await.tap_err(|err| tracing::error!("{:?}", err)); + } - // Wait for supergraph diff events received from the input stream. - while let Some(diff) = input.next().await { - match diff { - Ok(diff) => { - // If we detect additional diffs, start a new subgraph subtask. - // Adding the abort handle to the current collection of handles. - for (subgraph_name, subgraph_config) in diff.added() { - let _ = subgraph_handles.add( - subgraph_name, - subgraph_config, - self.introspection_polling_interval - ).await.tap_err(|err| tracing::error!("{:?}", err)); - } - - for (subgraph_name, subgraph_config) in diff.changed() { - let _ = subgraph_handles.update( - subgraph_name, - subgraph_config, - self.introspection_polling_interval - ).await.tap_err(|err| tracing::error!("{:?}", err)); - } + for (subgraph_name, subgraph_config) in diff.changed() { + let _ = subgraph_handles.update( + subgraph_name, + subgraph_config, + self.introspection_polling_interval + ).await.tap_err(|err| tracing::error!("{:?}", err)); + } - // If we detect removal diffs, stop the subtask for the removed subgraph. - for subgraph_name in diff.removed() { - eprintln!("Removing subgraph from session: `{}`", subgraph_name); - subgraph_handles.remove(subgraph_name); + // If we detect removal diffs, stop the subtask for the removed subgraph. + for subgraph_name in diff.removed() { + eprintln!("Removing subgraph from session: `{}`", subgraph_name); + subgraph_handles.remove(subgraph_name); + } } - } - Err(errs) => { - if let SupergraphConfigSerialisationError::ResolvingSubgraphErrors(errs) = errs { - for (subgraph_name, _) in errs { - errln!("Error detected with the config for {}. Removing it from the session.", subgraph_name); - subgraph_handles.remove(&subgraph_name); + Err(errs) => { + if let SupergraphConfigSerialisationError::ResolvingSubgraphErrors(errs) = errs { + for (subgraph_name, _) in errs { + errln!("Error detected with the config for {}. Removing it from the session.", subgraph_name); + subgraph_handles.remove(&subgraph_name); + } } } } } - } - }) - .abort_handle() + }).await + }); } } struct SubgraphHandles { - abort_handles: HashMap, + abort_handles: HashMap, sender: UnboundedSender, resolve_introspect_subgraph_factory: ResolveIntrospectSubgraphFactory, fetch_remote_subgraph_factory: FetchRemoteSubgraphFactory, @@ -244,20 +246,24 @@ impl SubgraphHandles { // shut down. for (subgraph_name, watcher) in watchers.into_iter() { let (mut messages, subtask) = Subtask::<_, FullyResolvedSubgraph>::new(watcher); - let messages_abort_handle = tokio::task::spawn({ + let cancellation_token = CancellationToken::new(); + let sender = sender.clone(); + subtask.run(Some(cancellation_token.clone())); + abort_handles.insert(subgraph_name, cancellation_token.clone()); + tokio::task::spawn(async move { let sender = sender.clone(); - async move { - while let Some(subgraph) = messages.next().await { - tracing::info!("Subgraph change detected: {:?}", subgraph); - let _ = sender - .send(Subgraph(SubgraphEvent::SubgraphChanged(subgraph.into()))) - .tap_err(|err| tracing::error!("{:?}", err)); - } - } - }) - .abort_handle(); - let subtask_abort_handle = subtask.run(); - abort_handles.insert(subgraph_name, (messages_abort_handle, subtask_abort_handle)); + let cancellation_token = cancellation_token.clone(); + cancellation_token + .run_until_cancelled(async move { + while let Some(subgraph) = messages.next().await { + tracing::info!("Subgraph change detected: {:?}", subgraph); + let _ = sender + .send(Subgraph(SubgraphEvent::SubgraphChanged(subgraph.into()))) + .tap_err(|err| tracing::error!("{:?}", err)); + } + }) + .await; + }); } SubgraphHandles { sender, @@ -353,9 +359,8 @@ impl SubgraphHandles { } pub fn remove(&mut self, subgraph: &str) { - if let Some(abort_handle) = self.abort_handles.get(subgraph) { - abort_handle.0.abort(); - abort_handle.1.abort(); + if let Some(cancellation_token) = self.abort_handles.get(subgraph) { + cancellation_token.cancel(); self.abort_handles.remove(subgraph); } @@ -392,6 +397,7 @@ impl SubgraphHandles { ) -> Result<(), ResolveSubgraphError> { let fetch = subgraph_watcher.watcher().clone(); let subgraph = fetch.fetch().await?; + let cancellation_token = CancellationToken::new(); let (mut messages, subtask) = Subtask::::new(subgraph_watcher); let _ = self @@ -401,22 +407,27 @@ impl SubgraphHandles { ))) .tap_err(|err| tracing::error!("{:?}", err)); - let messages_abort_handle = tokio::spawn({ + tokio::spawn({ let sender = self.sender.clone(); + let cancellation_token = cancellation_token.clone(); async move { - while let Some(subgraph) = messages.next().await { - let _ = sender - .send(Subgraph(SubgraphEvent::SubgraphChanged(subgraph.into()))) - .tap_err(|err| tracing::error!("{:?}", err)); + loop { + select! { + _ = cancellation_token.cancelled() => break, + Some(subgraph) = messages.next() => { + tracing::info!("Subgraph change detected: {:?}", subgraph); + let _ = sender + .send(Subgraph(SubgraphEvent::SubgraphChanged(subgraph.into()))) + .tap_err(|err| tracing::error!("{:?}", err)); + } + else => break + } } } - }) - .abort_handle(); - let subtask_abort_handle = subtask.run(); - self.abort_handles.insert( - subgraph.name().to_string(), - (messages_abort_handle, subtask_abort_handle), - ); + }); + subtask.run(Some(cancellation_token.clone())); + self.abort_handles + .insert(subgraph.name().to_string(), cancellation_token); Ok(()) } } diff --git a/src/composition/watchers/watcher/file.rs b/src/composition/watchers/watcher/file.rs index 40a88f206..421397607 100644 --- a/src/composition/watchers/watcher/file.rs +++ b/src/composition/watchers/watcher/file.rs @@ -7,7 +7,7 @@ use rover_std::{errln, Fs, RoverStdError}; use tap::TapFallible; use tokio::sync::{mpsc::unbounded_channel, Mutex}; use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_util::sync::DropGuard; +use tokio_util::sync::{CancellationToken, DropGuard}; use tower::{Service, ServiceExt}; use crate::composition::supergraph::config::{ @@ -47,7 +47,7 @@ impl FileWatcher { pub async fn watch(&self) -> BoxStream<'static, String> { let (file_tx, file_rx) = unbounded_channel(); let output = UnboundedReceiverStream::new(file_rx); - let cancellation_token = Fs::watch_file(self.path.as_path().into(), file_tx); + let cancellation_token = Fs::watch_file(self.path.as_path().into(), file_tx, None); self.drop_guard .set(cancellation_token.clone().drop_guard()) .unwrap(); @@ -114,10 +114,17 @@ impl SubgraphFileWatcher { /// Development note: in the future, we might consider a way to kill the watcher when the /// rover-std::fs filewatcher dies. Right now, the stream remains active and we can /// indefinitely loop on a close filewatcher - pub async fn watch(self) -> BoxStream<'static, FullyResolvedSubgraph> { + pub async fn watch( + self, + cancellation_token: CancellationToken, + ) -> BoxStream<'static, FullyResolvedSubgraph> { let (file_tx, file_rx) = unbounded_channel(); let output = UnboundedReceiverStream::new(file_rx); - let cancellation_token = Fs::watch_file(self.path.as_path().into(), file_tx); + let cancellation_token = Fs::watch_file( + self.path.as_path().into(), + file_tx, + Some(cancellation_token), + ); { let mut drop_guard = self.drop_guard.lock().await; let _ = drop_guard.insert(cancellation_token.clone().drop_guard()); @@ -208,7 +215,7 @@ mod tests { let watcher = SubgraphFileWatcher::new(path.clone(), resolve_file_subgraph); // SubgraphFileWatcher has a DropGuard associated with it that cancels the underlying FileWatcher's CancellationToken when dropped, so we must retain a reference until this test finishes. This can be fixed if we migrate this to a `Subtask` implementation to make it safer and more explicit let _watcher = watcher.clone(); - let mut watching = watcher.watch().await; + let mut watching = watcher.watch(CancellationToken::default()).await; let _ = tokio::time::sleep(Duration::from_millis(500)).await; let mut writeable_file = OpenOptions::new() diff --git a/src/composition/watchers/watcher/introspection.rs b/src/composition/watchers/watcher/introspection.rs index f9c43782a..7f1bb023e 100644 --- a/src/composition/watchers/watcher/introspection.rs +++ b/src/composition/watchers/watcher/introspection.rs @@ -1,6 +1,8 @@ use std::{marker::Send, pin::Pin, time::Duration}; use futures::{Stream, StreamExt}; +use rover_std::{errln, infoln}; +use tokio_util::sync::CancellationToken; use tower::{Service, ServiceExt}; use crate::{ @@ -12,8 +14,6 @@ use crate::{ watch::Watch, }; -use rover_std::{errln, infoln}; - /// Subgraph introspection #[derive(Debug, Clone)] pub struct SubgraphIntrospection { @@ -36,13 +36,16 @@ impl SubgraphIntrospection { // TODO: better typing so that it's over some impl, not string; makes all watch() fns require // returning a string - pub fn watch(self) -> Pin + Send>> { + pub fn watch( + self, + cancellation_token: CancellationToken, + ) -> Pin + Send>> { let watch = Watch::builder() .polling_interval(self.polling_interval) .service(self.resolver.clone()) .build(); let (watch_messages, watch_subtask) = Subtask::new(watch); - watch_subtask.run(); + watch_subtask.run(Some(cancellation_token)); // Stream any subgraph changes, filtering out empty responses (None) while passing along // the sdl changes diff --git a/src/composition/watchers/watcher/subgraph.rs b/src/composition/watchers/watcher/subgraph.rs index b3b7b9e33..8ff765bd0 100644 --- a/src/composition/watchers/watcher/subgraph.rs +++ b/src/composition/watchers/watcher/subgraph.rs @@ -5,8 +5,10 @@ use futures::{stream::BoxStream, StreamExt}; use rover_client::operations::subgraph::introspect::SubgraphIntrospectError; use rover_std::{infoln, RoverStdError}; use tap::TapFallible; -use tokio::{sync::mpsc::UnboundedSender, task::AbortHandle}; +use tokio::sync::mpsc::UnboundedSender; +use tokio_util::sync::CancellationToken; use tower::{Service, ServiceExt}; +use tracing::debug; use super::{file::SubgraphFileWatcher, introspection::SubgraphIntrospection}; use crate::{ @@ -115,10 +117,13 @@ impl SubgraphWatcherKind { /// /// Development note: this is a stream of Strings, but in the future we might want something /// more flexible to get type safety. - async fn watch(self) -> Option> { + async fn watch( + self, + cancellation_token: CancellationToken, + ) -> Option> { match self { - Self::File(file_watcher) => Some(file_watcher.watch().await), - Self::Introspect(introspection) => Some(introspection.watch()), + Self::File(file_watcher) => Some(file_watcher.watch(cancellation_token.clone()).await), + Self::Introspect(introspection) => Some(introspection.watch(cancellation_token)), kind => { tracing::debug!("{kind:?} is not watchable. Skipping"); None @@ -142,18 +147,28 @@ impl SubgraphWatcherKind { impl SubtaskHandleUnit for SubgraphWatcher { type Output = FullyResolvedSubgraph; - fn handle(self, sender: UnboundedSender) -> AbortHandle { + fn handle( + self, + sender: UnboundedSender, + cancellation_token: Option, + ) { let watcher = self.watcher.clone(); + let cancellation_token = cancellation_token.unwrap_or_default(); tokio::spawn(async move { - let stream = watcher.watch().await; + let stream = watcher.watch(cancellation_token.clone()).await; if let Some(mut stream) = stream { - while let Some(subgraph) = stream.next().await { - let _ = sender - .send(subgraph) - .tap_err(|err| tracing::error!("{:?}", err)); - } + debug!("Watching subgraph"); + cancellation_token + .run_until_cancelled(async move { + while let Some(subgraph) = stream.next().await { + let _ = sender + .send(subgraph) + .tap_err(|err| tracing::error!("{:?}", err)); + } + }) + .await; + debug!("Watching is now finished!"); } - }) - .abort_handle() + }); } } diff --git a/src/composition/watchers/watcher/supergraph_config.rs b/src/composition/watchers/watcher/supergraph_config.rs index 2479fd012..cb290b42f 100644 --- a/src/composition/watchers/watcher/supergraph_config.rs +++ b/src/composition/watchers/watcher/supergraph_config.rs @@ -11,7 +11,7 @@ use rover_std::errln; use tap::TapFallible; use thiserror::Error; use tokio::sync::broadcast::Sender; -use tokio::task::AbortHandle; +use tokio_util::sync::CancellationToken; use tracing::debug; use super::file::FileWatcher; @@ -44,88 +44,89 @@ impl SupergraphConfigWatcher { impl SubtaskHandleMultiStream for SupergraphConfigWatcher { type Output = Result; - fn handle(self, sender: Sender) -> AbortHandle { + fn handle(self, sender: Sender, cancellation_token: Option) { tracing::warn!("Running SupergraphConfigWatcher"); let supergraph_config_path = self.file_watcher.path().clone(); - tokio::spawn( - async move { - let supergraph_config_path = supergraph_config_path.clone(); - let mut latest_supergraph_config = self.supergraph_config.clone(); - let mut stream = self.file_watcher.watch().await; - while let Some(contents) = stream.next().await { - eprintln!("{} changed. Applying changes to the session.", supergraph_config_path); - tracing::info!( - "{} changed. Parsing it as a `SupergraphConfig`", - supergraph_config_path - ); - match SupergraphConfig::new_from_yaml(&contents) { - Ok(supergraph_config) => { - let subgraphs = BTreeMap::from_iter(supergraph_config.clone().into_iter()); - let unresolved_supergraph_config = UnresolvedSupergraphConfig::builder() - .origin_path(supergraph_config_path.clone()) - .subgraphs(subgraphs) - .federation_version_resolver(FederationVersionResolver::default().from_supergraph_config(Some(&supergraph_config))) - .build(); - let supergraph_config = LazilyResolvedSupergraphConfig::resolve( - &supergraph_config_path.parent().unwrap().to_path_buf(), - unresolved_supergraph_config, - ).await.map(SupergraphConfig::from); - - match supergraph_config { - Ok(supergraph_config) => { - let supergraph_config_diff = SupergraphConfigDiff::new( - &latest_supergraph_config, - supergraph_config.clone(), - ); - match supergraph_config_diff { - Ok(supergraph_config_diff) => { - debug!("{supergraph_config_diff}"); - let _ = sender - .send(Ok(supergraph_config_diff)) - .tap_err(|err| tracing::error!("{:?}", err)); + let cancellation_token = cancellation_token.unwrap_or_default(); + tokio::spawn(async move { + let supergraph_config_path = supergraph_config_path.clone(); + let mut latest_supergraph_config = self.supergraph_config.clone(); + let mut stream = self.file_watcher.watch().await; + cancellation_token.run_until_cancelled(async move { + while let Some(contents) = stream.next().await { + eprintln!("{} changed. Applying changes to the session.", supergraph_config_path); + tracing::info!( + "{} changed. Parsing it as a `SupergraphConfig`", + supergraph_config_path + ); + match SupergraphConfig::new_from_yaml(&contents) { + Ok(supergraph_config) => { + let subgraphs = BTreeMap::from_iter(supergraph_config.clone().into_iter()); + let unresolved_supergraph_config = UnresolvedSupergraphConfig::builder() + .origin_path(supergraph_config_path.clone()) + .subgraphs(subgraphs) + .federation_version_resolver(FederationVersionResolver::default().from_supergraph_config(Some(&supergraph_config))) + .build(); + let supergraph_config = LazilyResolvedSupergraphConfig::resolve( + &supergraph_config_path.parent().unwrap().to_path_buf(), + unresolved_supergraph_config, + ).await.map(SupergraphConfig::from); + + match supergraph_config { + Ok(supergraph_config) => { + let supergraph_config_diff = SupergraphConfigDiff::new( + &latest_supergraph_config, + supergraph_config.clone(), + ); + match supergraph_config_diff { + Ok(supergraph_config_diff) => { + debug!("{supergraph_config_diff}"); + let _ = sender + .send(Ok(supergraph_config_diff)) + .tap_err(|err| tracing::error!("{:?}", err)); + } + Err(err) => { + tracing::error!("Failed to construct a diff between the current and previous `SupergraphConfig`s.\n{}", err); + } } - Err(err) => { - tracing::error!("Failed to construct a diff between the current and previous `SupergraphConfig`s.\n{}", err); - } - } - latest_supergraph_config = supergraph_config; - } - Err(resolution_errors) => { - errln!( - "Failed to lazily resolve the supergraph config at {}.\n{}", - supergraph_config_path, - itertools::join( - resolution_errors - .iter() - .map( - |(name, err)| format!("{}: {}", name, err) - ), - "\n") - ); - // Since we have errors we need to remove these subgraphs from - // what we're tracking **and** emit events to make sure they - // get removed in downstream processes as well. - latest_supergraph_config = remove_errored_subgraphs(latest_supergraph_config, resolution_errors.clone()); - let _ = sender - .send(Err(SupergraphConfigSerialisationError::ResolvingSubgraphErrors(resolution_errors))) - .tap_err(|err| tracing::error!("{:?}", err)); + latest_supergraph_config = supergraph_config; + } + Err(resolution_errors) => { + errln!( + "Failed to lazily resolve the supergraph config at {}.\n{}", + supergraph_config_path, + itertools::join( + resolution_errors + .iter() + .map( + |(name, err)| format!("{}: {}", name, err) + ), + "\n") + ); + // Since we have errors we need to remove these subgraphs from + // what we're tracking **and** emit events to make sure they + // get removed in downstream processes as well. + latest_supergraph_config = remove_errored_subgraphs(latest_supergraph_config, resolution_errors.clone()); + let _ = sender + .send(Err(SupergraphConfigSerialisationError::ResolvingSubgraphErrors(resolution_errors))) + .tap_err(|err| tracing::error!("{:?}", err)); + } } } - } - Err(err) => { - tracing::error!("could not parse supergraph config file: {:?}", err); - errln!("Could not parse supergraph config file.\n{}", err); - let _ = sender - .send(Err(DeserializingConfigError { - source: Arc::new(err) - })) - .tap_err(|err| tracing::error!("{:?}", err)); + Err(err) => { + tracing::error!("could not parse supergraph config file: {:?}", err); + errln!("Could not parse supergraph config file.\n{}", err); + let _ = sender + .send(Err(DeserializingConfigError { + source: Arc::new(err) + })) + .tap_err(|err| tracing::error!("{:?}", err)); + } } } - } - }) - .abort_handle() + }).await; + }); } } diff --git a/src/subtask.rs b/src/subtask.rs index 66b7157dc..3bb1d2129 100644 --- a/src/subtask.rs +++ b/src/subtask.rs @@ -58,16 +58,18 @@ use futures::stream::BoxStream; use tokio::sync::broadcast; use tokio::sync::broadcast::Sender; -use tokio::{ - sync::mpsc::{unbounded_channel, UnboundedSender}, - task::AbortHandle, -}; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tokio_stream::wrappers::{BroadcastStream, UnboundedReceiverStream}; +use tokio_util::sync::CancellationToken; /// A trait whose implementation will be able to send events pub trait SubtaskHandleUnit { type Output; - fn handle(self, sender: UnboundedSender) -> AbortHandle; + fn handle( + self, + sender: UnboundedSender, + cancellation_token: Option, + ); } /// A trait whose implementation will be able to both send and receive events @@ -78,25 +80,30 @@ pub trait SubtaskHandleStream { self, sender: UnboundedSender, input: BoxStream<'static, Self::Input>, - ) -> AbortHandle; + cancellation_token: Option, + ); } /// A trait whose implementation will be able to send events to multiple channels with /// broadcast semantics. pub trait SubtaskHandleMultiStream { type Output; - fn handle(self, sender: Sender) -> AbortHandle; + fn handle(self, sender: Sender, cancellation_token: Option); } /// A trait whose implementation can run a subtask that only ingests messages pub trait SubtaskRunUnit { - fn run(self) -> AbortHandle; + fn run(self, cancellation_token: Option); } /// A trait whose implementation can run a subtask that can both ingest messages and emit them pub trait SubtaskRunStream { type Input; - fn run(self, input: BoxStream<'static, Self::Input>) -> AbortHandle; + fn run( + self, + input: BoxStream<'static, Self::Input>, + cancellation_token: Option, + ); } /// A background task that can emit messages via a sender channel @@ -144,16 +151,20 @@ impl BroadcastSubtask { impl, Output> SubtaskRunUnit for Subtask { /// Begin running the subtask, calling handle() on the type implementing the SubTaskHandleUnit trait - fn run(self) -> AbortHandle { - self.inner.handle(self.sender) + fn run(self, cancellation_token: Option) { + self.inner.handle(self.sender, cancellation_token) } } impl, Output> SubtaskRunStream for Subtask { type Input = T::Input; /// Begin running the subtask with a stream of events, calling handle() on the type implementing the SubTaskHandleStream trait - fn run(self, input: BoxStream<'static, Self::Input>) -> AbortHandle { - self.inner.handle(self.sender, input) + fn run( + self, + input: BoxStream<'static, Self::Input>, + cancellation_token: Option, + ) { + self.inner.handle(self.sender, input, cancellation_token) } } @@ -161,7 +172,7 @@ impl, Output> SubtaskRunUnit for BroadcastSubtask { /// Begin running the subtask, calling handle() on the type implementing the SubTaskHandleUnit trait - fn run(self) -> AbortHandle { - self.inner.handle(self.sender) + fn run(self, cancellation_token: Option) { + self.inner.handle(self.sender, cancellation_token) } } diff --git a/src/watch.rs b/src/watch.rs index 16c42d193..482af4530 100644 --- a/src/watch.rs +++ b/src/watch.rs @@ -2,6 +2,7 @@ use std::time::Duration; use buildstructor::Builder; use tap::TapFallible; +use tokio_util::sync::CancellationToken; use tower::{Service, ServiceExt}; use crate::subtask::SubtaskHandleUnit; @@ -26,46 +27,52 @@ where fn handle( self, sender: tokio::sync::mpsc::UnboundedSender, - ) -> tokio::task::AbortHandle { + cancellation_token: Option, + ) { let mut service = self.service.clone(); + let cancellation_token = cancellation_token.unwrap_or_default(); tokio::task::spawn(async move { - let service = service.ready().await.unwrap(); - let mut last_result: Option> = None; - loop { - match service.call(()).await { - Ok(output) => { - let mut was_updated = true; - if let Some(Ok(last)) = last_result { - if last == output { - was_updated = false + let cancellation_token = cancellation_token.clone(); + cancellation_token + .run_until_cancelled(async move { + let service = service.ready().await.unwrap(); + let mut last_result: Option> = None; + loop { + match service.call(()).await { + Ok(output) => { + let mut was_updated = true; + if let Some(Ok(last)) = last_result { + if last == output { + was_updated = false + } + } + if was_updated { + let _ = sender + .send(Ok(output.clone())) + .tap_err(|err| tracing::error!("{:?}", err)); + } + last_result = Some(Ok(output)); } - } - if was_updated { - let _ = sender - .send(Ok(output.clone())) - .tap_err(|err| tracing::error!("{:?}", err)); - } - last_result = Some(Ok(output)); - } - Err(error) => { - let mut was_updated = true; - let e = error.to_string(); - if let Some(Err(last)) = last_result { - if last == e { - was_updated = false; + Err(error) => { + let mut was_updated = true; + let e = error.to_string(); + if let Some(Err(last)) = last_result { + if last == e { + was_updated = false; + } + } + if was_updated { + let _ = sender + .send(Err(error)) + .tap_err(|err| tracing::error!("{:?}", err)); + } + last_result = Some(Err(e)); } } - if was_updated { - let _ = sender - .send(Err(error)) - .tap_err(|err| tracing::error!("{:?}", err)); - } - last_result = Some(Err(e)); + tokio::time::sleep(self.polling_interval).await } - } - tokio::time::sleep(self.polling_interval).await - } - }) - .abort_handle() + }) + .await; + }); } }