Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor PyErr state to reduce blast radius of threading challenge #4650

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 132 additions & 55 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,109 @@
use std::cell::UnsafeCell;

use crate::{
exceptions::{PyBaseException, PyTypeError},
ffi,
types::{PyTraceback, PyType},
Bound, IntoPy, Py, PyAny, PyObject, PyTypeInfo, Python,
Bound, Py, PyAny, PyErrArguments, PyObject, PyTypeInfo, Python,
};

pub(crate) struct PyErrState {
// Safety: can only hand out references when in the "normalized" state. Will never change
// after normalization.
//
// The state is temporarily removed from the PyErr during normalization, to avoid
// concurrent modifications.
inner: UnsafeCell<Option<PyErrStateInner>>,
}

// The inner value is only accessed through ways that require the gil is held.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}

impl PyErrState {
pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
Self::from_inner(PyErrStateInner::Lazy(f))
}

pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
PyErrStateLazyFnOutput {
ptype,
pvalue: args.arguments(py),
}
})))
}

#[cfg(not(Py_3_12))]
pub(crate) fn ffi_tuple(
ptype: PyObject,
pvalue: Option<PyObject>,
ptraceback: Option<PyObject>,
) -> Self {
Self::from_inner(PyErrStateInner::FfiTuple {
ptype,
pvalue,
ptraceback,
})
}

pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
Self::from_inner(PyErrStateInner::Normalized(normalized))
}

pub(crate) fn restore(self, py: Python<'_>) {
self.inner
.into_inner()
.expect("PyErr state should never be invalid outside of normalization")
.restore(py)
}

fn from_inner(inner: PyErrStateInner) -> Self {
Self {
inner: UnsafeCell::new(Some(inner)),
}
}

#[inline]
pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
if let Some(PyErrStateInner::Normalized(n)) = unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
return n;
}

self.make_normalized(py)
}

#[cold]
fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
// This process is safe because:
// - Access is guaranteed not to be concurrent thanks to `Python` GIL token
// - Write happens only once, and then never will change again.
// - State is set to None during the normalization process, so that a second
// concurrent normalization attempt will panic before changing anything.

// FIXME: this needs to be rewritten to deal with free-threaded Python
// see https://github.com/PyO3/pyo3/issues/4584

let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};

unsafe {
let self_state = &mut *self.inner.get();
*self_state = Some(PyErrStateInner::Normalized(state.normalize(py)));
match self_state {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
}
}
}
}

pub(crate) struct PyErrStateNormalized {
#[cfg(not(Py_3_12))]
ptype: Py<PyType>,
Expand All @@ -14,6 +113,24 @@ pub(crate) struct PyErrStateNormalized {
}

impl PyErrStateNormalized {
pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
#[cfg(not(Py_3_12))]
use crate::types::any::PyAnyMethods;

Self {
#[cfg(not(Py_3_12))]
ptype: pvalue.get_type().into(),
#[cfg(not(Py_3_12))]
ptraceback: unsafe {
Py::from_owned_ptr_or_opt(
pvalue.py(),
ffi::PyException_GetTraceback(pvalue.as_ptr()),
)
},
pvalue: pvalue.into(),
}
}

#[cfg(not(Py_3_12))]
pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
self.ptype.bind(py).clone()
Expand Down Expand Up @@ -85,7 +202,7 @@ pub(crate) struct PyErrStateLazyFnOutput {
pub(crate) type PyErrStateLazyFn =
dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;

pub(crate) enum PyErrState {
enum PyErrStateInner {
Lazy(Box<PyErrStateLazyFn>),
#[cfg(not(Py_3_12))]
FfiTuple {
Expand All @@ -96,66 +213,26 @@ pub(crate) enum PyErrState {
Normalized(PyErrStateNormalized),
}

/// Helper conversion trait that allows to use custom arguments for lazy exception construction.
pub trait PyErrArguments: Send + Sync {
/// Arguments for exception
fn arguments(self, py: Python<'_>) -> PyObject;
}

impl<T> PyErrArguments for T
where
T: IntoPy<PyObject> + Send + Sync,
{
fn arguments(self, py: Python<'_>) -> PyObject {
self.into_py(py)
}
}

impl PyErrState {
pub(crate) fn lazy(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
PyErrState::Lazy(Box::new(move |py| PyErrStateLazyFnOutput {
ptype,
pvalue: args.arguments(py),
}))
}

pub(crate) fn normalized(pvalue: Bound<'_, PyBaseException>) -> Self {
#[cfg(not(Py_3_12))]
use crate::types::any::PyAnyMethods;

Self::Normalized(PyErrStateNormalized {
#[cfg(not(Py_3_12))]
ptype: pvalue.get_type().into(),
#[cfg(not(Py_3_12))]
ptraceback: unsafe {
Py::from_owned_ptr_or_opt(
pvalue.py(),
ffi::PyException_GetTraceback(pvalue.as_ptr()),
)
},
pvalue: pvalue.into(),
})
}

pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
impl PyErrStateInner {
fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
match self {
#[cfg(not(Py_3_12))]
PyErrState::Lazy(lazy) => {
PyErrStateInner::Lazy(lazy) => {
let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
unsafe {
PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
}
}
#[cfg(Py_3_12)]
PyErrState::Lazy(lazy) => {
PyErrStateInner::Lazy(lazy) => {
// To keep the implementation simple, just write the exception into the interpreter,
// which will cause it to be normalized
raise_lazy(py, lazy);
PyErrStateNormalized::take(py)
.expect("exception missing after writing to the interpreter")
}
#[cfg(not(Py_3_12))]
PyErrState::FfiTuple {
PyErrStateInner::FfiTuple {
ptype,
pvalue,
ptraceback,
Expand All @@ -168,15 +245,15 @@ impl PyErrState {
PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
}
}
PyErrState::Normalized(normalized) => normalized,
PyErrStateInner::Normalized(normalized) => normalized,
}
}

#[cfg(not(Py_3_12))]
pub(crate) fn restore(self, py: Python<'_>) {
fn restore(self, py: Python<'_>) {
let (ptype, pvalue, ptraceback) = match self {
PyErrState::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
PyErrState::FfiTuple {
PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
PyErrStateInner::FfiTuple {
ptype,
pvalue,
ptraceback,
Expand All @@ -185,7 +262,7 @@ impl PyErrState {
pvalue.map_or(std::ptr::null_mut(), Py::into_ptr),
ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
),
PyErrState::Normalized(PyErrStateNormalized {
PyErrStateInner::Normalized(PyErrStateNormalized {
ptype,
pvalue,
ptraceback,
Expand All @@ -199,10 +276,10 @@ impl PyErrState {
}

#[cfg(Py_3_12)]
pub(crate) fn restore(self, py: Python<'_>) {
fn restore(self, py: Python<'_>) {
match self {
PyErrState::Lazy(lazy) => raise_lazy(py, lazy),
PyErrState::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
ffi::PyErr_SetRaisedException(pvalue.into_ptr())
},
}
Expand Down
Loading
Loading