Skip to content

Commit

Permalink
Add is_locked_serially() to check if we are in a #[serial] context
Browse files Browse the repository at this point in the history
  • Loading branch information
pgerber committed Jul 12, 2024
1 parent b39310b commit 1a1ab58
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 0 deletions.
124 changes: 124 additions & 0 deletions serial_test/src/code_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ impl UniqueReentrantMutex {
pub fn is_locked(&self) -> bool {
self.locks.is_locked()
}

pub fn is_locked_by_current_thread(&self) -> bool {
self.locks.is_locked_by_current_thread()
}
}

#[inline]
Expand All @@ -44,6 +48,63 @@ pub(crate) fn global_locks() -> &'static HashMap<String, UniqueReentrantMutex> {
LOCKS.get_or_init(HashMap::new)
}

/// Check if the current thread is holding a serial lock
///
/// Can be used to assert that a piece of code can only be called
/// from a test marked `#[serial]`.
///
/// Example, with `#[serial]`:
///
/// ```
/// use serial_test::{is_locked_serially, serial};
///
/// fn do_something_in_need_of_serialization() {
/// assert!(is_locked_serially(None));
///
/// // ...
/// }
///
/// #[test]
/// # fn unused() {}
/// #[serial]
/// fn main() {
/// do_something_in_need_of_serialization();
/// }
/// ```
///
/// Example, missing `#[serial]`:
///
/// ```should_panic
/// use serial_test::{is_locked_serially, serial};
///
/// #[test]
/// # fn unused() {}
/// // #[serial] // <-- missing
/// fn main() {
/// assert!(is_locked_serially(None));
/// }
/// ```
///
/// Example, `#[test(some_key)]`:
///
/// ```
/// use serial_test::{is_locked_serially, serial};
///
/// #[test]
/// # fn unused() {}
/// #[serial(some_key)]
/// fn main() {
/// assert!(is_locked_serially(Some("some_key")));
/// assert!(!is_locked_serially(None));
/// }
/// ```
pub fn is_locked_serially(name: Option<&str>) -> bool {
global_locks()
.get(name.unwrap_or_default())
.map(|lock| lock.get().is_locked_by_current_thread())
.unwrap_or_default()
}

static MUTEX_ID: AtomicU32 = AtomicU32::new(1);

impl UniqueReentrantMutex {
Expand All @@ -68,3 +129,66 @@ pub(crate) fn check_new_key(name: &str) {
Entry::Vacant(v) => v.insert_entry(UniqueReentrantMutex::new_mutex(name)),
};
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{local_parallel_core, local_serial_core};

const NAME1: &str = "NAME1";
const NAME2: &str = "NAME2";

#[test]
fn assert_serially_locked_without_name() {
local_serial_core(vec![""], None, || {
assert!(is_locked_serially(None));
assert!(!is_locked_serially(Some("no_such_name")));
});
}

#[test]
fn assert_serially_locked_with_multiple_names() {
local_serial_core(vec![NAME1, NAME2], None, || {
assert!(is_locked_serially(Some(NAME1)));
assert!(is_locked_serially(Some(NAME2)));
assert!(!is_locked_serially(Some("no_such_name")));
assert!(!is_locked_serially(None));
});
}

#[test]
fn assert_serially_locked_when_actually_locked_parallel() {
local_parallel_core(vec![NAME1, NAME2], None, || {
assert!(!is_locked_serially(Some(NAME1)));
assert!(!is_locked_serially(Some(NAME2)));
assert!(!is_locked_serially(Some("no_such_name")));
assert!(!is_locked_serially(None));
});
}

#[test]
fn assert_serially_locked_outside_serial_lock() {
assert!(!is_locked_serially(Some(NAME1)));
assert!(!is_locked_serially(Some(NAME2)));
assert!(!is_locked_serially(None));

local_serial_core(vec![NAME1], None, || {
// ...
});

assert!(!is_locked_serially(Some(NAME1)));
assert!(!is_locked_serially(Some(NAME2)));
assert!(!is_locked_serially(None));
}

#[test]
fn assert_serially_locked_in_different_thread() {
local_serial_core(vec![NAME1, NAME2], None, || {
std::thread::spawn(|| {
assert!(!is_locked_serially(Some(NAME2)));
})
.join()
.unwrap();
});
}
}
2 changes: 2 additions & 0 deletions serial_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,5 @@ pub use serial_test_derive::{parallel, serial};

#[cfg(feature = "file_locks")]
pub use serial_test_derive::{file_parallel, file_serial};

pub use code_lock::is_locked_serially;
4 changes: 4 additions & 0 deletions serial_test/src/rwlock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ impl Locks {
self.arc.serial.is_locked()
}

pub fn is_locked_by_current_thread(&self) -> bool {
self.arc.serial.is_owned_by_current_thread()
}

pub fn serial(&self) -> MutexGuardWrapper {
#[cfg(feature = "logging")]
debug!("Get serial lock '{}'", self.name);
Expand Down

0 comments on commit 1a1ab58

Please sign in to comment.