From c0f08c289fd52b1d52ca4bddc45af12438a04aed Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Wed, 8 Jan 2025 15:57:37 -0700 Subject: [PATCH] Implement locked iteration for PyList (#4789) * implement locked iteration for PyList * fix limited API and PyPy support * fix formatting of safety docstrings * only define fold and rfold on not(feature = "nightly") * add missing try_fold implementation on nightly * Use split borrows for locked iteration for PyList Inline ListIterImpl implementations by using split borrows and destructuring let Self { .. } = self destructuring inside BoundListIterator impls. Signed-off-by: Manos Pitsidianakis * use a function to do the split borrow * add changelog entries * fix clippy on limited API and PyPy * use a macro for the split borrow * add a test that mutates the list during a fold * enable next_unchecked on PyPy * fix incorrect docstring for locked_for_each * simplify borrows by adding BoundListIterator::with_critical_section * fix build on GIL-enabled and limited API builds * fix docs build on MSRV --------- Signed-off-by: Manos Pitsidianakis Co-authored-by: Manos Pitsidianakis --- newsfragments/4789.added.md | 3 + newsfragments/4789.changed.md | 2 + src/types/list.rs | 482 +++++++++++++++++++++++++++++++--- 3 files changed, 454 insertions(+), 33 deletions(-) create mode 100644 newsfragments/4789.added.md create mode 100644 newsfragments/4789.changed.md diff --git a/newsfragments/4789.added.md b/newsfragments/4789.added.md new file mode 100644 index 00000000000..fab564a8962 --- /dev/null +++ b/newsfragments/4789.added.md @@ -0,0 +1,3 @@ +* Added `PyList::locked_for_each`, which is equivalent to `PyList::for_each` on + the GIL-enabled build and uses a critical section to lock the list on the + free-threaded build, similar to `PyDict::locked_for_each`. diff --git a/newsfragments/4789.changed.md b/newsfragments/4789.changed.md new file mode 100644 index 00000000000..d20419e8f23 --- /dev/null +++ b/newsfragments/4789.changed.md @@ -0,0 +1,2 @@ +* Operations that process a PyList via an iterator now use a critical section + on the free-threaded build to amortize synchronization cost and prevent race conditions. diff --git a/src/types/list.rs b/src/types/list.rs index 2e124c82400..76da36d00b9 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -179,7 +179,9 @@ pub trait PyListMethods<'py>: crate::sealed::Sealed { /// # Safety /// /// Caller must verify that the index is within the bounds of the list. - #[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] + /// On the free-threaded build, caller must verify they have exclusive access to the list + /// via a lock or by holding the innermost critical section on the list. + #[cfg(not(Py_LIMITED_API))] unsafe fn get_item_unchecked(&self, index: usize) -> Bound<'py, PyAny>; /// Takes the slice `self[low:high]` and returns it as a new list. @@ -239,6 +241,17 @@ pub trait PyListMethods<'py>: crate::sealed::Sealed { /// Returns an iterator over this list's items. fn iter(&self) -> BoundListIterator<'py>; + /// Iterates over the contents of this list while holding a critical section on the list. + /// This is useful when the GIL is disabled and the list is shared between threads. + /// It is not guaranteed that the list will not be modified during iteration when the + /// closure calls arbitrary Python code that releases the critical section held by the + /// iterator. Otherwise, the list will not be modified during iteration. + /// + /// This is equivalent to for_each if the GIL is enabled. + fn locked_for_each(&self, closure: F) -> PyResult<()> + where + F: Fn(Bound<'py, PyAny>) -> PyResult<()>; + /// Sorts the list in-place. Equivalent to the Python expression `l.sort()`. fn sort(&self) -> PyResult<()>; @@ -302,7 +315,7 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> { /// # Safety /// /// Caller must verify that the index is within the bounds of the list. - #[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] + #[cfg(not(Py_LIMITED_API))] unsafe fn get_item_unchecked(&self, index: usize) -> Bound<'py, PyAny> { // PyList_GET_ITEM return borrowed ptr; must make owned for safety (see #890). ffi::PyList_GET_ITEM(self.as_ptr(), index as Py_ssize_t) @@ -440,6 +453,14 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> { BoundListIterator::new(self.clone()) } + /// Iterates over a list while holding a critical section, calling a closure on each item + fn locked_for_each(&self, closure: F) -> PyResult<()> + where + F: Fn(Bound<'py, PyAny>) -> PyResult<()>, + { + crate::sync::with_critical_section(self, || self.iter().try_for_each(closure)) + } + /// Sorts the list in-place. Equivalent to the Python expression `l.sort()`. fn sort(&self) -> PyResult<()> { err::error_on_minusone(self.py(), unsafe { ffi::PyList_Sort(self.as_ptr()) }) @@ -462,73 +483,332 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> { } } +// New types for type checking when using BoundListIterator associated methods, like +// BoundListIterator::next_unchecked. +struct Index(usize); +struct Length(usize); + /// Used by `PyList::iter()`. pub struct BoundListIterator<'py> { list: Bound<'py, PyList>, - index: usize, - length: usize, + index: Index, + length: Length, } impl<'py> BoundListIterator<'py> { fn new(list: Bound<'py, PyList>) -> Self { - let length: usize = list.len(); - BoundListIterator { + Self { + index: Index(0), + length: Length(list.len()), list, - index: 0, - length, } } - unsafe fn get_item(&self, index: usize) -> Bound<'py, PyAny> { - #[cfg(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED))] - let item = self.list.get_item(index).expect("list.get failed"); - #[cfg(not(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED)))] - let item = self.list.get_item_unchecked(index); - item + /// # Safety + /// + /// On the free-threaded build, caller must verify they have exclusive + /// access to the list by holding a lock or by holding the innermost + /// critical section on the list. + #[inline] + #[cfg(not(Py_LIMITED_API))] + #[deny(unsafe_op_in_unsafe_fn)] + unsafe fn next_unchecked( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + ) -> Option> { + let length = length.0.min(list.len()); + let my_index = index.0; + + if index.0 < length { + let item = unsafe { list.get_item_unchecked(my_index) }; + index.0 += 1; + Some(item) + } else { + None + } } -} -impl<'py> Iterator for BoundListIterator<'py> { - type Item = Bound<'py, PyAny>; + #[cfg(Py_LIMITED_API)] + fn next( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + ) -> Option> { + let length = length.0.min(list.len()); + let my_index = index.0; + if index.0 < length { + let item = list.get_item(my_index).expect("get-item failed"); + index.0 += 1; + Some(item) + } else { + None + } + } + + /// # Safety + /// + /// On the free-threaded build, caller must verify they have exclusive + /// access to the list by holding a lock or by holding the innermost + /// critical section on the list. #[inline] - fn next(&mut self) -> Option { - let length = self.length.min(self.list.len()); + #[cfg(not(Py_LIMITED_API))] + #[deny(unsafe_op_in_unsafe_fn)] + unsafe fn next_back_unchecked( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + ) -> Option> { + let current_length = length.0.min(list.len()); + + if index.0 < current_length { + let item = unsafe { list.get_item_unchecked(current_length - 1) }; + length.0 = current_length - 1; + Some(item) + } else { + None + } + } - if self.index < length { - let item = unsafe { self.get_item(self.index) }; - self.index += 1; + #[inline] + #[cfg(Py_LIMITED_API)] + fn next_back( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + ) -> Option> { + let current_length = (length.0).min(list.len()); + + if index.0 < current_length { + let item = list.get_item(current_length - 1).expect("get-item failed"); + length.0 = current_length - 1; Some(item) } else { None } } + #[cfg(not(Py_LIMITED_API))] + fn with_critical_section( + &mut self, + f: impl FnOnce(&mut Index, &mut Length, &Bound<'py, PyList>) -> R, + ) -> R { + let Self { + index, + length, + list, + } = self; + crate::sync::with_critical_section(list, || f(index, length, list)) + } +} + +impl<'py> Iterator for BoundListIterator<'py> { + type Item = Bound<'py, PyAny>; + + #[inline] + fn next(&mut self) -> Option { + #[cfg(not(Py_LIMITED_API))] + { + self.with_critical_section(|index, length, list| unsafe { + Self::next_unchecked(index, length, list) + }) + } + #[cfg(Py_LIMITED_API)] + { + let Self { + index, + length, + list, + } = self; + Self::next(index, length, list) + } + } + #[inline] fn size_hint(&self) -> (usize, Option) { let len = self.len(); (len, Some(len)) } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn fold(mut self, init: B, mut f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + self.with_critical_section(|index, length, list| { + let mut accum = init; + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { + accum = f(accum, x); + } + accum + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, feature = "nightly"))] + fn try_fold(&mut self, init: B, mut f: F) -> R + where + Self: Sized, + F: FnMut(B, Self::Item) -> R, + R: std::ops::Try, + { + self.with_critical_section(|index, length, list| { + let mut accum = init; + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { + accum = f(accum, x)? + } + R::from_output(accum) + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn all(&mut self, mut f: F) -> bool + where + Self: Sized, + F: FnMut(Self::Item) -> bool, + { + self.with_critical_section(|index, length, list| { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { + if !f(x) { + return false; + } + } + true + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn any(&mut self, mut f: F) -> bool + where + Self: Sized, + F: FnMut(Self::Item) -> bool, + { + self.with_critical_section(|index, length, list| { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { + if f(x) { + return true; + } + } + false + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn find

(&mut self, mut predicate: P) -> Option + where + Self: Sized, + P: FnMut(&Self::Item) -> bool, + { + self.with_critical_section(|index, length, list| { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { + if predicate(&x) { + return Some(x); + } + } + None + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn find_map(&mut self, mut f: F) -> Option + where + Self: Sized, + F: FnMut(Self::Item) -> Option, + { + self.with_critical_section(|index, length, list| { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { + if let found @ Some(_) = f(x) { + return found; + } + } + None + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn position

(&mut self, mut predicate: P) -> Option + where + Self: Sized, + P: FnMut(Self::Item) -> bool, + { + self.with_critical_section(|index, length, list| { + let mut acc = 0; + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { + if predicate(x) { + return Some(acc); + } + acc += 1; + } + None + }) + } } impl DoubleEndedIterator for BoundListIterator<'_> { #[inline] fn next_back(&mut self) -> Option { - let length = self.length.min(self.list.len()); - - if self.index < length { - let item = unsafe { self.get_item(length - 1) }; - self.length = length - 1; - Some(item) - } else { - None + #[cfg(not(Py_LIMITED_API))] + { + self.with_critical_section(|index, length, list| unsafe { + Self::next_back_unchecked(index, length, list) + }) + } + #[cfg(Py_LIMITED_API)] + { + let Self { + index, + length, + list, + } = self; + Self::next_back(index, length, list) } } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn rfold(mut self, init: B, mut f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + self.with_critical_section(|index, length, list| { + let mut accum = init; + while let Some(x) = unsafe { Self::next_back_unchecked(index, length, list) } { + accum = f(accum, x); + } + accum + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, feature = "nightly"))] + fn try_rfold(&mut self, init: B, mut f: F) -> R + where + Self: Sized, + F: FnMut(B, Self::Item) -> R, + R: std::ops::Try, + { + self.with_critical_section(|index, length, list| { + let mut accum = init; + while let Some(x) = unsafe { Self::next_back_unchecked(index, length, list) } { + accum = f(accum, x)? + } + R::from_output(accum) + }) + } } impl ExactSizeIterator for BoundListIterator<'_> { fn len(&self) -> usize { - self.length.saturating_sub(self.index) + self.length.0.saturating_sub(self.index.0) } } @@ -558,7 +838,7 @@ mod tests { use crate::types::list::PyListMethods; use crate::types::sequence::PySequenceMethods; use crate::types::{PyList, PyTuple}; - use crate::{ffi, IntoPyObject, Python}; + use crate::{ffi, IntoPyObject, PyResult, Python}; #[test] fn test_new() { @@ -748,6 +1028,142 @@ mod tests { }); } + #[test] + fn test_iter_all() { + Python::with_gil(|py| { + let list = PyList::new(py, [true, true, true]).unwrap(); + assert!(list.iter().all(|x| x.extract::().unwrap())); + + let list = PyList::new(py, [true, false, true]).unwrap(); + assert!(!list.iter().all(|x| x.extract::().unwrap())); + }); + } + + #[test] + fn test_iter_any() { + Python::with_gil(|py| { + let list = PyList::new(py, [true, true, true]).unwrap(); + assert!(list.iter().any(|x| x.extract::().unwrap())); + + let list = PyList::new(py, [true, false, true]).unwrap(); + assert!(list.iter().any(|x| x.extract::().unwrap())); + + let list = PyList::new(py, [false, false, false]).unwrap(); + assert!(!list.iter().any(|x| x.extract::().unwrap())); + }); + } + + #[test] + fn test_iter_find() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, ["hello", "world"]).unwrap(); + assert_eq!( + Some("world".to_string()), + list.iter() + .find(|v| v.extract::().unwrap() == "world") + .map(|v| v.extract::().unwrap()) + ); + assert_eq!( + None, + list.iter() + .find(|v| v.extract::().unwrap() == "foobar") + .map(|v| v.extract::().unwrap()) + ); + }); + } + + #[test] + fn test_iter_position() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, ["hello", "world"]).unwrap(); + assert_eq!( + Some(1), + list.iter() + .position(|v| v.extract::().unwrap() == "world") + ); + assert_eq!( + None, + list.iter() + .position(|v| v.extract::().unwrap() == "foobar") + ); + }); + } + + #[test] + fn test_iter_fold() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list + .iter() + .fold(0, |acc, v| acc + v.extract::().unwrap()); + assert_eq!(sum, 6); + }); + } + + #[test] + fn test_iter_fold_out_of_bounds() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list.iter().fold(0, |_, _| { + // clear the list to create a pathological fold operation + // that mutates the list as it processes it + for _ in 0..3 { + list.del_item(0).unwrap(); + } + -5 + }); + assert_eq!(sum, -5); + assert!(list.len() == 0); + }); + } + + #[test] + fn test_iter_rfold() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list + .iter() + .rfold(0, |acc, v| acc + v.extract::().unwrap()); + assert_eq!(sum, 6); + }); + } + + #[test] + fn test_iter_try_fold() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list + .iter() + .try_fold(0, |acc, v| PyResult::Ok(acc + v.extract::()?)) + .unwrap(); + assert_eq!(sum, 6); + + let list = PyList::new(py, ["foo", "bar"]).unwrap(); + assert!(list + .iter() + .try_fold(0, |acc, v| PyResult::Ok(acc + v.extract::()?)) + .is_err()); + }); + } + + #[test] + fn test_iter_try_rfold() { + Python::with_gil(|py: Python<'_>| { + let list = PyList::new(py, [1, 2, 3]).unwrap(); + let sum = list + .iter() + .try_rfold(0, |acc, v| PyResult::Ok(acc + v.extract::()?)) + .unwrap(); + assert_eq!(sum, 6); + + let list = PyList::new(py, ["foo", "bar"]).unwrap(); + assert!(list + .iter() + .try_rfold(0, |acc, v| PyResult::Ok(acc + v.extract::()?)) + .is_err()); + }); + } + #[test] fn test_into_iter() { Python::with_gil(|py| { @@ -877,7 +1293,7 @@ mod tests { }); } - #[cfg(not(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED)))] + #[cfg(not(Py_LIMITED_API))] #[test] fn test_list_get_item_unchecked_sanity() { Python::with_gil(|py| {