diff --git a/datafusion/common-runtime/src/common.rs b/datafusion/common-runtime/src/common.rs index eeff2f5a0ab4..295867a78f2d 100644 --- a/datafusion/common-runtime/src/common.rs +++ b/datafusion/common-runtime/src/common.rs @@ -79,48 +79,29 @@ impl SpawnedTask { mod tests { use super::*; - use std::{ - future::{pending, Pending}, - sync::{Arc, Mutex}, - }; + use std::future::{pending, Pending}; use tokio::runtime::Runtime; #[tokio::test] async fn runtime_shutdown() { - // capture the panic message - let panic_msg = Arc::new(Mutex::new(None)); - let captured_panic_msg = Arc::clone(&panic_msg); - std::panic::set_hook(Box::new(move |e| { - let mut guard = captured_panic_msg.lock().unwrap(); - *guard = Some(e.to_string()); - })); - - for _ in 0..30 { - let rt = Runtime::new().unwrap(); - let join = rt.spawn(async { - let task = SpawnedTask::spawn(async { + let rt = Runtime::new().unwrap(); + let task = rt + .spawn(async { + SpawnedTask::spawn(async { let fut: Pending<()> = pending(); fut.await; unreachable!("should never return"); - }); - let _ = task.join_unwind().await; - }); - - // caller shutdown their DF runtime (e.g. timeout, error in caller, etc) - rt.shutdown_background(); + }) + }) + .await + .unwrap(); - // race condition - // poll occurs during shutdown (buffered stream poll calls, etc) - let _ = join.await; - } + // caller shutdown their DF runtime (e.g. timeout, error in caller, etc) + rt.shutdown_background(); - // demonstrate that we hit the unreachable code - let maybe_panic = panic_msg.lock().unwrap().clone(); - assert_eq!( - maybe_panic, None, - "should not have rt thread panic, instead found {:?}", - maybe_panic - ); + // race condition + // poll occurs during shutdown (buffered stream poll calls, etc) + let _ = task.join_unwind().await; } }