From d338040837713a1e0c668b1ed9f72331484a2732 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Mon, 20 Jun 2022 18:55:51 +0100 Subject: [PATCH 1/2] Cope better with asserts inside parallel blocks --- Cargo.lock | 73 ++++++++++++-- serial_test/Cargo.toml | 2 + serial_test/src/code_lock.rs | 10 ++ serial_test/src/parallel_code_lock.rs | 131 ++++++++++++++++++++++++-- serial_test/src/rwlock.rs | 11 +++ serial_test/src/serial_code_lock.rs | 12 +++ 6 files changed, 224 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b895e9c..08815aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,28 +96,81 @@ dependencies = [ "winapi", ] +[[package]] +name = "futures" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" +dependencies = [ + "futures-core", + "futures-sink", +] + [[package]] name = "futures-core" -version = "0.3.15" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" + +[[package]] +name = "futures-executor" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" +checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" + +[[package]] +name = "futures-sink" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" [[package]] name = "futures-task" -version = "0.3.15" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae" +checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" [[package]] name = "futures-util" -version = "0.3.14" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c144ad54d60f23927f0a6b6d816e4271278b64f005ad65e4e35291d2de9c025" +checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -388,11 +441,13 @@ version = "0.7.0" dependencies = [ "document-features", "fslock", + "futures", "itertools", "lazy_static", "log", "parking_lot", "serial_test_derive", + "tokio", ] [[package]] @@ -430,6 +485,12 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" + [[package]] name = "smallvec" version = "1.8.0" diff --git a/serial_test/Cargo.toml b/serial_test/Cargo.toml index db25746..4d1a0b9 100644 --- a/serial_test/Cargo.toml +++ b/serial_test/Cargo.toml @@ -17,9 +17,11 @@ serial_test_derive = { version = "~0.7.0", path = "../serial_test_derive" } fslock = {version = "0.2", optional = true} document-features = {version = "0.2", optional=true} log = "0.4" +futures = {version = "^0.3", default_features = false, features = ["executor"] } [dev-dependencies] itertools = "0.10" +tokio = { version = "^1.17", features = ["macros", "rt"] } [features] default = [] diff --git a/serial_test/src/code_lock.rs b/serial_test/src/code_lock.rs index 24534c4..9bd99a7 100644 --- a/serial_test/src/code_lock.rs +++ b/serial_test/src/code_lock.rs @@ -30,6 +30,16 @@ impl UniqueReentrantMutex { pub(crate) fn end_parallel(&self) { self.locks.end_parallel(); } + + #[cfg(test)] + pub fn parallel_count(&self) -> u32 { + self.locks.parallel_count() + } + + #[cfg(test)] + pub fn is_locked(&self) -> bool { + self.locks.is_locked() + } } lazy_static! { diff --git a/serial_test/src/parallel_code_lock.rs b/serial_test/src/parallel_code_lock.rs index e521389..1f48517 100644 --- a/serial_test/src/parallel_code_lock.rs +++ b/serial_test/src/parallel_code_lock.rs @@ -1,7 +1,8 @@ #![allow(clippy::await_holding_lock)] use crate::code_lock::{check_new_key, LOCKS}; -use std::ops::Deref; +use futures::FutureExt; +use std::{ops::Deref, panic}; #[doc(hidden)] pub fn local_parallel_core_with_return( @@ -12,9 +13,14 @@ pub fn local_parallel_core_with_return( let unlock = LOCKS.read_recursive(); unlock.deref()[name].start_parallel(); - let ret = function(); + let res = panic::catch_unwind(|| function()); unlock.deref()[name].end_parallel(); - ret + match res { + Ok(ret) => ret, + Err(err) => { + panic::resume_unwind(err); + } + } } #[doc(hidden)] @@ -23,30 +29,137 @@ pub fn local_parallel_core(name: &str, function: fn()) { let unlock = LOCKS.read_recursive(); unlock.deref()[name].start_parallel(); - function(); + let res = panic::catch_unwind(|| { + function(); + }); unlock.deref()[name].end_parallel(); + if let Err(err) = res { + panic::resume_unwind(err); + } } #[doc(hidden)] pub async fn local_async_parallel_core_with_return( name: &str, - fut: impl std::future::Future>, + fut: impl std::future::Future> + panic::UnwindSafe, ) -> Result<(), E> { check_new_key(name); let unlock = LOCKS.read_recursive(); unlock.deref()[name].start_parallel(); - let ret = fut.await; + let res = fut.catch_unwind().await; unlock.deref()[name].end_parallel(); - ret + match res { + Ok(ret) => ret, + Err(err) => { + panic::resume_unwind(err); + } + } } #[doc(hidden)] -pub async fn local_async_parallel_core(name: &str, fut: impl std::future::Future) { +pub async fn local_async_parallel_core( + name: &str, + fut: impl std::future::Future + panic::UnwindSafe, +) { check_new_key(name); let unlock = LOCKS.read_recursive(); unlock.deref()[name].start_parallel(); - fut.await; + let res = fut.catch_unwind().await; unlock.deref()[name].end_parallel(); + if let Err(err) = res { + panic::resume_unwind(err); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + code_lock::LOCKS, local_async_parallel_core, local_async_parallel_core_with_return, + local_parallel_core, local_parallel_core_with_return, + }; + use std::{io::Error, ops::Deref, panic}; + + #[test] + fn unlock_on_assert_sync_without_return() { + let _ = panic::catch_unwind(|| { + local_parallel_core("unlock_on_assert_sync_without_return", || { + assert!(false); + }) + }); + let unlock = LOCKS.read_recursive(); + assert_eq!( + unlock.deref()["unlock_on_assert_sync_without_return"].parallel_count(), + 0 + ); + } + + #[test] + fn unlock_on_assert_sync_with_return() { + let _ = panic::catch_unwind(|| { + local_parallel_core_with_return( + "unlock_on_assert_sync_with_return", + || -> Result<(), Error> { + assert!(false); + Ok(()) + }, + ) + }); + let unlock = LOCKS.read_recursive(); + assert_eq!( + unlock.deref()["unlock_on_assert_sync_with_return"].parallel_count(), + 0 + ); + } + + #[tokio::test] + async fn unlock_on_assert_async_without_return() { + async fn demo_assert() { + assert!(false); + } + async fn call_serial_test_fn() { + local_async_parallel_core("unlock_on_assert_async_without_return", demo_assert()).await + } + // as per https://stackoverflow.com/a/66529014/320546 + let _ = panic::catch_unwind(|| { + let handle = tokio::runtime::Handle::current(); + let _enter_guard = handle.enter(); + futures::executor::block_on(call_serial_test_fn()); + }); + let unlock = LOCKS.read_recursive(); + assert_eq!( + unlock.deref()["unlock_on_assert_async_without_return"].parallel_count(), + 0 + ); + } + + #[tokio::test] + async fn unlock_on_assert_async_with_return() { + async fn demo_assert() -> Result<(), Error> { + assert!(false); + Ok(()) + } + + #[allow(unused_must_use)] + async fn call_serial_test_fn() { + local_async_parallel_core_with_return( + "unlock_on_assert_async_with_return", + demo_assert(), + ) + .await; + } + + // as per https://stackoverflow.com/a/66529014/320546 + let _ = panic::catch_unwind(|| { + let handle = tokio::runtime::Handle::current(); + let _enter_guard = handle.enter(); + futures::executor::block_on(call_serial_test_fn()); + }); + let unlock = LOCKS.read_recursive(); + assert_eq!( + unlock.deref()["unlock_on_assert_async_with_return"].parallel_count(), + 0 + ); + } } diff --git a/serial_test/src/rwlock.rs b/serial_test/src/rwlock.rs index 13d8ab2..c61944b 100644 --- a/serial_test/src/rwlock.rs +++ b/serial_test/src/rwlock.rs @@ -39,6 +39,11 @@ impl Locks { } } + #[cfg(test)] + pub fn is_locked(&self) -> bool { + self.arc.serial.is_locked() + } + pub fn serial(&self) -> MutexGuardWrapper { let mut lock_state = self.arc.mutex.lock(); loop { @@ -88,4 +93,10 @@ impl Locks { drop(lock_state); self.arc.condvar.notify_one(); } + + #[cfg(test)] + pub fn parallel_count(&self) -> u32 { + let lock_state = self.arc.mutex.lock(); + lock_state.parallels + } } diff --git a/serial_test/src/serial_code_lock.rs b/serial_test/src/serial_code_lock.rs index 5bd689d..853d9ac 100644 --- a/serial_test/src/serial_code_lock.rs +++ b/serial_test/src/serial_code_lock.rs @@ -51,6 +51,7 @@ pub async fn local_async_serial_core(name: &str, fut: impl std::future::Future Date: Mon, 20 Jun 2022 19:07:41 +0100 Subject: [PATCH 2/2] Clippy improvement to local_parallel_core_with_return --- serial_test/src/parallel_code_lock.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serial_test/src/parallel_code_lock.rs b/serial_test/src/parallel_code_lock.rs index 1f48517..e854cec 100644 --- a/serial_test/src/parallel_code_lock.rs +++ b/serial_test/src/parallel_code_lock.rs @@ -13,7 +13,7 @@ pub fn local_parallel_core_with_return( let unlock = LOCKS.read_recursive(); unlock.deref()[name].start_parallel(); - let res = panic::catch_unwind(|| function()); + let res = panic::catch_unwind(function); unlock.deref()[name].end_parallel(); match res { Ok(ret) => ret,