From 5266a72d63a1c77321631527921fbbdd580bd72d Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Fri, 25 Oct 2024 22:06:51 +0100 Subject: [PATCH] refactor `PyErr` state to reduce blast radius of threading challenge (#4650) * refactor `PyErr` state to reduce blast radius of threading challenge * restrict visibility --- src/err/err_state.rs | 187 ++++++++++++++++++++++++++++++------------- src/err/mod.rs | 96 +++++++--------------- 2 files changed, 162 insertions(+), 121 deletions(-) diff --git a/src/err/err_state.rs b/src/err/err_state.rs index dc07294a0fa..58303b46f32 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -1,10 +1,109 @@ +use std::cell::UnsafeCell; + use crate::{ exceptions::{PyBaseException, PyTypeError}, ffi, types::{PyTraceback, PyType}, - Bound, IntoPy, Py, PyAny, PyObject, PyTypeInfo, Python, + Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python, }; +pub(crate) struct PyErrState { + // Safety: can only hand out references when in the "normalized" state. Will never change + // after normalization. + // + // The state is temporarily removed from the PyErr during normalization, to avoid + // concurrent modifications. + inner: UnsafeCell>, +} + +// The inner value is only accessed through ways that require the gil is held. +unsafe impl Send for PyErrState {} +unsafe impl Sync for PyErrState {} + +impl PyErrState { + pub(crate) fn lazy(f: Box) -> Self { + Self::from_inner(PyErrStateInner::Lazy(f)) + } + + pub(crate) fn lazy_arguments(ptype: Py, args: impl PyErrArguments + 'static) -> Self { + Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| { + PyErrStateLazyFnOutput { + ptype, + pvalue: args.arguments(py), + } + }))) + } + + #[cfg(not(Py_3_12))] + pub(crate) fn ffi_tuple( + ptype: PyObject, + pvalue: Option, + ptraceback: Option, + ) -> Self { + Self::from_inner(PyErrStateInner::FfiTuple { + ptype, + pvalue, + ptraceback, + }) + } + + pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self { + Self::from_inner(PyErrStateInner::Normalized(normalized)) + } + + pub(crate) fn restore(self, py: Python<'_>) { + self.inner + .into_inner() + .expect("PyErr state should never be invalid outside of normalization") + .restore(py) + } + + fn from_inner(inner: PyErrStateInner) -> Self { + Self { + inner: UnsafeCell::new(Some(inner)), + } + } + + #[inline] + pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized { + if let Some(PyErrStateInner::Normalized(n)) = unsafe { + // Safety: self.inner will never be written again once normalized. + &*self.inner.get() + } { + return n; + } + + self.make_normalized(py) + } + + #[cold] + fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized { + // This process is safe because: + // - Access is guaranteed not to be concurrent thanks to `Python` GIL token + // - Write happens only once, and then never will change again. + // - State is set to None during the normalization process, so that a second + // concurrent normalization attempt will panic before changing anything. + + // FIXME: this needs to be rewritten to deal with free-threaded Python + // see https://github.com/PyO3/pyo3/issues/4584 + + let state = unsafe { + (*self.inner.get()) + .take() + .expect("Cannot normalize a PyErr while already normalizing it.") + }; + + unsafe { + let self_state = &mut *self.inner.get(); + *self_state = Some(PyErrStateInner::Normalized(state.normalize(py))); + match self_state { + Some(PyErrStateInner::Normalized(n)) => n, + _ => unreachable!(), + } + } + } +} + pub(crate) struct PyErrStateNormalized { #[cfg(not(Py_3_12))] ptype: Py, @@ -14,6 +113,24 @@ pub(crate) struct PyErrStateNormalized { } impl PyErrStateNormalized { + pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self { + #[cfg(not(Py_3_12))] + use crate::types::any::PyAnyMethods; + + Self { + #[cfg(not(Py_3_12))] + ptype: pvalue.get_type().into(), + #[cfg(not(Py_3_12))] + ptraceback: unsafe { + Py::from_owned_ptr_or_opt( + pvalue.py(), + ffi::PyException_GetTraceback(pvalue.as_ptr()), + ) + }, + pvalue: pvalue.into(), + } + } + #[cfg(not(Py_3_12))] pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> { self.ptype.bind(py).clone() @@ -85,7 +202,7 @@ pub(crate) struct PyErrStateLazyFnOutput { pub(crate) type PyErrStateLazyFn = dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync; -pub(crate) enum PyErrState { +enum PyErrStateInner { Lazy(Box), #[cfg(not(Py_3_12))] FfiTuple { @@ -96,58 +213,18 @@ pub(crate) enum PyErrState { Normalized(PyErrStateNormalized), } -/// Helper conversion trait that allows to use custom arguments for lazy exception construction. -pub trait PyErrArguments: Send + Sync { - /// Arguments for exception - fn arguments(self, py: Python<'_>) -> PyObject; -} - -impl PyErrArguments for T -where - T: IntoPy + Send + Sync, -{ - fn arguments(self, py: Python<'_>) -> PyObject { - self.into_py(py) - } -} - -impl PyErrState { - pub(crate) fn lazy(ptype: Py, args: impl PyErrArguments + 'static) -> Self { - PyErrState::Lazy(Box::new(move |py| PyErrStateLazyFnOutput { - ptype, - pvalue: args.arguments(py), - })) - } - - pub(crate) fn normalized(pvalue: Bound<'_, PyBaseException>) -> Self { - #[cfg(not(Py_3_12))] - use crate::types::any::PyAnyMethods; - - Self::Normalized(PyErrStateNormalized { - #[cfg(not(Py_3_12))] - ptype: pvalue.get_type().into(), - #[cfg(not(Py_3_12))] - ptraceback: unsafe { - Py::from_owned_ptr_or_opt( - pvalue.py(), - ffi::PyException_GetTraceback(pvalue.as_ptr()), - ) - }, - pvalue: pvalue.into(), - }) - } - - pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized { +impl PyErrStateInner { + fn normalize(self, py: Python<'_>) -> PyErrStateNormalized { match self { #[cfg(not(Py_3_12))] - PyErrState::Lazy(lazy) => { + PyErrStateInner::Lazy(lazy) => { let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy); unsafe { PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback) } } #[cfg(Py_3_12)] - PyErrState::Lazy(lazy) => { + PyErrStateInner::Lazy(lazy) => { // To keep the implementation simple, just write the exception into the interpreter, // which will cause it to be normalized raise_lazy(py, lazy); @@ -155,7 +232,7 @@ impl PyErrState { .expect("exception missing after writing to the interpreter") } #[cfg(not(Py_3_12))] - PyErrState::FfiTuple { + PyErrStateInner::FfiTuple { ptype, pvalue, ptraceback, @@ -168,15 +245,15 @@ impl PyErrState { PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback) } } - PyErrState::Normalized(normalized) => normalized, + PyErrStateInner::Normalized(normalized) => normalized, } } #[cfg(not(Py_3_12))] - pub(crate) fn restore(self, py: Python<'_>) { + fn restore(self, py: Python<'_>) { let (ptype, pvalue, ptraceback) = match self { - PyErrState::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy), - PyErrState::FfiTuple { + PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy), + PyErrStateInner::FfiTuple { ptype, pvalue, ptraceback, @@ -185,7 +262,7 @@ impl PyErrState { pvalue.map_or(std::ptr::null_mut(), Py::into_ptr), ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr), ), - PyErrState::Normalized(PyErrStateNormalized { + PyErrStateInner::Normalized(PyErrStateNormalized { ptype, pvalue, ptraceback, @@ -199,10 +276,10 @@ impl PyErrState { } #[cfg(Py_3_12)] - pub(crate) fn restore(self, py: Python<'_>) { + fn restore(self, py: Python<'_>) { match self { - PyErrState::Lazy(lazy) => raise_lazy(py, lazy), - PyErrState::Normalized(PyErrStateNormalized { pvalue }) => unsafe { + PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy), + PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe { ffi::PyErr_SetRaisedException(pvalue.into_ptr()) }, } diff --git a/src/err/mod.rs b/src/err/mod.rs index 59d99f72c06..14c368938f1 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -11,14 +11,12 @@ use crate::{ }; use crate::{Borrowed, BoundObject, IntoPy, Py, PyAny, PyObject, Python}; use std::borrow::Cow; -use std::cell::UnsafeCell; use std::ffi::{CStr, CString}; mod err_state; mod impls; use crate::conversion::IntoPyObject; -pub use err_state::PyErrArguments; use err_state::{PyErrState, PyErrStateLazyFnOutput, PyErrStateNormalized}; use std::convert::Infallible; @@ -32,19 +30,12 @@ use std::convert::Infallible; /// [`get_type_bound`](PyErr::get_type_bound), or [`is_instance_bound`](PyErr::is_instance_bound) /// will create the full exception object if it was not already created. pub struct PyErr { - // Safety: can only hand out references when in the "normalized" state. Will never change - // after normalization. - // - // The state is temporarily removed from the PyErr during normalization, to avoid - // concurrent modifications. - state: UnsafeCell>, + state: PyErrState, } // The inner value is only accessed through ways that require proving the gil is held #[cfg(feature = "nightly")] unsafe impl crate::marker::Ungil for PyErr {} -unsafe impl Send for PyErr {} -unsafe impl Sync for PyErr {} /// Represents the result of a Python call. pub type PyResult = Result; @@ -102,6 +93,21 @@ impl<'py> DowncastIntoError<'py> { } } +/// Helper conversion trait that allows to use custom arguments for lazy exception construction. +pub trait PyErrArguments: Send + Sync { + /// Arguments for exception + fn arguments(self, py: Python<'_>) -> PyObject; +} + +impl PyErrArguments for T +where + T: IntoPy + Send + Sync, +{ + fn arguments(self, py: Python<'_>) -> PyObject { + self.into_py(py) + } +} + impl PyErr { /// Creates a new PyErr of type `T`. /// @@ -160,7 +166,7 @@ impl PyErr { T: PyTypeInfo, A: PyErrArguments + Send + Sync + 'static, { - PyErr::from_state(PyErrState::Lazy(Box::new(move |py| { + PyErr::from_state(PyErrState::lazy(Box::new(move |py| { PyErrStateLazyFnOutput { ptype: T::type_object(py).into(), pvalue: args.arguments(py), @@ -182,7 +188,7 @@ impl PyErr { where A: PyErrArguments + Send + Sync + 'static, { - PyErr::from_state(PyErrState::lazy(ty.unbind().into_any(), args)) + PyErr::from_state(PyErrState::lazy_arguments(ty.unbind().into_any(), args)) } /// Deprecated name for [`PyErr::from_type`]. @@ -230,13 +236,13 @@ impl PyErr { /// ``` pub fn from_value(obj: Bound<'_, PyAny>) -> PyErr { let state = match obj.downcast_into::() { - Ok(obj) => PyErrState::normalized(obj), + Ok(obj) => PyErrState::normalized(PyErrStateNormalized::new(obj)), Err(err) => { // Assume obj is Type[Exception]; let later normalization handle if this // is not the case let obj = err.into_inner(); let py = obj.py(); - PyErrState::lazy(obj.into_py(py), py.None()) + PyErrState::lazy_arguments(obj.into_py(py), py.None()) } }; @@ -392,19 +398,13 @@ impl PyErr { .map(|py_str| py_str.to_string_lossy().into()) .unwrap_or_else(|| String::from("Unwrapped panic from Python code")); - let state = PyErrState::FfiTuple { - ptype, - pvalue, - ptraceback, - }; + let state = PyErrState::ffi_tuple(ptype, pvalue, ptraceback); Self::print_panic_and_unwind(py, state, msg) } - Some(PyErr::from_state(PyErrState::FfiTuple { - ptype, - pvalue, - ptraceback, - })) + Some(PyErr::from_state(PyErrState::ffi_tuple( + ptype, pvalue, ptraceback, + ))) } #[cfg(Py_3_12)] @@ -416,10 +416,10 @@ impl PyErr { .str() .map(|py_str| py_str.to_string_lossy().into()) .unwrap_or_else(|_| String::from("Unwrapped panic from Python code")); - Self::print_panic_and_unwind(py, PyErrState::Normalized(state), msg) + Self::print_panic_and_unwind(py, PyErrState::normalized(state), msg) } - Some(PyErr::from_state(PyErrState::Normalized(state))) + Some(PyErr::from_state(PyErrState::normalized(state))) } fn print_panic_and_unwind(py: Python<'_>, state: PyErrState, msg: String) -> ! { @@ -596,10 +596,7 @@ impl PyErr { /// This is the opposite of `PyErr::fetch()`. #[inline] pub fn restore(self, py: Python<'_>) { - self.state - .into_inner() - .expect("PyErr state should never be invalid outside of normalization") - .restore(py) + self.state.restore(py) } /// Reports the error as unraisable. @@ -774,7 +771,7 @@ impl PyErr { /// ``` #[inline] pub fn clone_ref(&self, py: Python<'_>) -> PyErr { - PyErr::from_state(PyErrState::Normalized(self.normalized(py).clone_ref(py))) + PyErr::from_state(PyErrState::normalized(self.normalized(py).clone_ref(py))) } /// Return the cause (either an exception instance, or None, set by `raise ... from ...`) @@ -808,45 +805,12 @@ impl PyErr { #[inline] fn from_state(state: PyErrState) -> PyErr { - PyErr { - state: UnsafeCell::new(Some(state)), - } + PyErr { state } } #[inline] fn normalized(&self, py: Python<'_>) -> &PyErrStateNormalized { - if let Some(PyErrState::Normalized(n)) = unsafe { - // Safety: self.state will never be written again once normalized. - &*self.state.get() - } { - return n; - } - - self.make_normalized(py) - } - - #[cold] - fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized { - // This process is safe because: - // - Access is guaranteed not to be concurrent thanks to `Python` GIL token - // - Write happens only once, and then never will change again. - // - State is set to None during the normalization process, so that a second - // concurrent normalization attempt will panic before changing anything. - - let state = unsafe { - (*self.state.get()) - .take() - .expect("Cannot normalize a PyErr while already normalizing it.") - }; - - unsafe { - let self_state = &mut *self.state.get(); - *self_state = Some(PyErrState::Normalized(state.normalize(py))); - match self_state { - Some(PyErrState::Normalized(n)) => n, - _ => unreachable!(), - } - } + self.state.as_normalized(py) } }