diff --git a/Cargo.toml b/Cargo.toml index a9a26811e32..2f8aff77efa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ unindent = { version = "0.2.1", optional = true } inventory = { version = "0.3.0", optional = true } # coroutine implementation -futures-task = "0.3" +futures= "0.3" # crate integrations that can be added using the eponymous features anyhow = { version = "1.0", optional = true } @@ -58,7 +58,6 @@ serde_json = "1.0.61" rayon = "1.0.2" rust_decimal = { version = "1.8.0", features = ["std"] } widestring = "0.5.1" -futures = "0.3.28" [build-dependencies] pyo3-build-config = { path = "pyo3-build-config", version = "0.20.0", features = ["resolve-config"] } diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 7ea65899850..af027f19d66 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -21,6 +21,7 @@ pub struct FnArg<'a> { pub optional: Option<&'a syn::Type>, pub default: Option, pub py: bool, + pub coroutine_cancel: bool, pub attrs: PyFunctionArgPyO3Attributes, pub is_varargs: bool, pub is_kwargs: bool, @@ -50,6 +51,7 @@ impl<'a> FnArg<'a> { optional: utils::option_type_argument(&cap.ty), default: None, py: utils::is_python(&cap.ty), + coroutine_cancel: utils::is_coroutine_cancel(&cap.ty), attrs: arg_attrs, is_varargs: false, is_kwargs: false, @@ -446,10 +448,27 @@ impl<'a> FnSpec<'a> { let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise); let func_name = &self.name; + let coroutine_cancel = self + .signature + .arguments + .iter() + .find(|arg| arg.coroutine_cancel); + if let (None, Some(arg)) = (&self.asyncness, coroutine_cancel) { + bail_spanned!(arg.ty.span() => "`CoroutineCancel` argument only allowed with `async fn`"); + } + let rust_call = |args: Vec| { let mut call = quote! { function(#self_arg #(#args),*) }; if self.asyncness.is_some() { - call = quote! { _pyo3::impl_::coroutine::wrap_future(#call) }; + call = if coroutine_cancel.is_some() { + quote! {{ + let __coroutine_cancel = _pyo3::coroutine::CoroutineCancel::new(); + let __cancel_handle = __coroutine_cancel.handle(); + _pyo3::impl_::coroutine::wrap_future({ #call }, Some(__cancel_handle)) + }} + } else { + quote! { _pyo3::impl_::coroutine::wrap_future(#call, None) } + }; } quotes::map_result_into_ptr(quotes::ok_wrap(call)) }; diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index e511ca754ac..c096e45029a 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -155,6 +155,10 @@ fn impl_arg_param( return Ok(quote! { py }); } + if arg.coroutine_cancel { + return Ok(quote! { __coroutine_cancel }); + } + let name = arg.name; let name_str = name.to_string(); diff --git a/pyo3-macros-backend/src/pyfunction/signature.rs b/pyo3-macros-backend/src/pyfunction/signature.rs index ed3256ad461..d440ad1f381 100644 --- a/pyo3-macros-backend/src/pyfunction/signature.rs +++ b/pyo3-macros-backend/src/pyfunction/signature.rs @@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> { // Otherwise try next argument. continue; } + if fn_arg.coroutine_cancel { + // If the user incorrectly tried to include cancel: CoroutineCancel in the + // signature, give a useful error as a hint. + ensure_spanned!( + name != fn_arg.name, + name.span() => "arguments of type `CoroutineCancel` must not be part of the signature" + ); + // Otherwise try next argument. + continue; + } ensure_spanned!( name == fn_arg.name, @@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> { } // Ensure no non-py arguments remain - if let Some(arg) = args_iter.find(|arg| !arg.py) { + if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.coroutine_cancel) { bail_spanned!( attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name) ); @@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> { let mut python_signature = PythonSignature::default(); for arg in &arguments { // Python<'_> arguments don't show in Python signature - if arg.py { + if arg.py || arg.coroutine_cancel { continue; } diff --git a/pyo3-macros-backend/src/utils.rs b/pyo3-macros-backend/src/utils.rs index 360b1ec2341..582dbf26430 100644 --- a/pyo3-macros-backend/src/utils.rs +++ b/pyo3-macros-backend/src/utils.rs @@ -41,6 +41,19 @@ pub fn is_python(ty: &syn::Type) -> bool { } } +/// Check if the given type `ty` is `pyo3::coroutine::CoroutineCancel`. +pub fn is_coroutine_cancel(ty: &syn::Type) -> bool { + match unwrap_ty_group(ty) { + syn::Type::Path(typath) => typath + .path + .segments + .last() + .map(|seg| seg.ident == "CoroutineCancel") + .unwrap_or(false), + _ => false, + } +} + /// If `ty` is `Option`, return `Some(T)`, else `None`. pub fn option_type_argument(ty: &syn::Type) -> Option<&syn::Type> { if let syn::Type::Path(syn::TypePath { path, .. }) = ty { diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index 37c7e6e9b99..52a4275c956 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -127,6 +127,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream { /// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. | /// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. | /// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. | +/// /// | `#[pyo3(coroutine_cancel = "...")]` | (`async fn` only) Pass a `CoroutineCancel` instance to the given parameter | /// /// For more on exposing functions see the [function section of the guide][1]. /// diff --git a/src/coroutine.rs b/src/coroutine.rs index 171d5894377..e6bc43eb19e 100644 --- a/src/coroutine.rs +++ b/src/coroutine.rs @@ -1,22 +1,28 @@ //! Python coroutine implementation, used notably when wrapping `async fn` //! with `#[pyfunction]`/`#[pymethods]`. +use crate::coroutine::waker::AsyncioWaker; use crate::exceptions::{PyRuntimeError, PyStopIteration}; use crate::pyclass::IterNextOutput; -use crate::sync::GILOnceCell; -use crate::types::{PyCFunction, PyIterator}; -use crate::{intern, wrap_pyfunction, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python}; -use pyo3_macros::{pyclass, pyfunction, pymethods}; +use crate::types::PyIterator; +use crate::{IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python}; +use pyo3_macros::{pyclass, pymethods}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +mod cancel; +mod waker; + +pub use crate::coroutine::cancel::{CancelHandle, CoroutineCancel}; + const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine"; /// Python coroutine wrapping a [`Future`]. #[pyclass(crate = "crate")] pub struct Coroutine { future: Option> + Send>>>, + cancel: Option, waker: Option>, } @@ -41,6 +47,32 @@ impl Coroutine { }; Self { future: Some(Box::pin(wrap)), + cancel: None, + waker: None, + } + } + + /// Wrap a future into a Python coroutine. + /// + /// Coroutine `send` polls the wrapped future, ignoring the value passed + /// (should always be `None` anyway). + /// + /// Coroutine `throw` registers the exception in `cancel`, and polls the wrapped future + pub fn from_future_with_cancel(future: F, cancel: CancelHandle) -> Self + where + F: Future> + Send + 'static, + T: IntoPy + Send, + E: Send, + PyErr: From, + { + let wrap = async move { + let obj = future.await?; + // SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`) + Ok(obj.into_py(unsafe { Python::assume_gil_acquired() })) + }; + Self { + future: Some(Box::pin(wrap)), + cancel: Some(cancel), waker: None, } } @@ -48,7 +80,7 @@ impl Coroutine { fn poll( &mut self, py: Python<'_>, - throw: Option<&PyAny>, + throw: Option, ) -> PyResult> { // raise if the coroutine has already been run to completion let future_rs = match self.future { @@ -57,8 +89,12 @@ impl Coroutine { }; // reraise thrown exception it if let Some(exc) = throw { - self.close(); - return Err(PyErr::from_value(exc)); + if let Some(ref handle) = self.cancel { + handle.cancel(py, exc) + } else { + self.close(); + return Err(PyErr::from_value(exc.as_ref(py))); + } } // create a new waker, or try to reset it in place if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) { @@ -66,7 +102,7 @@ impl Coroutine { } else { self.waker = Some(Arc::new(AsyncioWaker::new())); } - let waker = futures_task::waker(self.waker.clone().unwrap()); + let waker = futures::task::waker(self.waker.clone().unwrap()); // poll the Rust future and forward its results if ready if let Poll::Ready(res) = future_rs.as_mut().poll(&mut Context::from_waker(&waker)) { self.close(); @@ -101,7 +137,7 @@ impl Coroutine { iter_result(self.poll(py, None)?) } - fn throw(&mut self, py: Python<'_>, exc: &PyAny) -> PyResult { + fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult { iter_result(self.poll(py, Some(exc))?) } @@ -119,93 +155,3 @@ impl Coroutine { self.poll(py, None) } } - -/// Lazy `asyncio.Future` wrapper, implementing [`ArcWake`] by calling `Future.set_result`. -/// -/// asyncio future is let uninitialized until [`initialize_future`][1] is called. -/// If [`wake`][2] is called before future initialization (during Rust future polling), -/// [`initialize_future`][1] will return `None` (it is roughly equivalent to `asyncio.sleep(0)`) -/// -/// [1]: AsyncioWaker::initialize_future -/// [2]: AsyncioWaker::wake -struct AsyncioWaker(GILOnceCell>); - -impl AsyncioWaker { - fn new() -> Self { - Self(GILOnceCell::new()) - } - - fn reset(&mut self) { - self.0.take(); - } - - fn initialize_future<'a>(&'a self, py: Python<'a>) -> PyResult> { - let init = || LoopAndFuture::new(py).map(Some); - let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref(); - Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.as_ref(py))) - } -} - -impl futures_task::ArcWake for AsyncioWaker { - fn wake_by_ref(arc_self: &Arc) { - Python::with_gil(|gil| { - if let Some(loop_and_future) = arc_self.0.get_or_init(gil, || None) { - loop_and_future - .set_result(gil) - .expect("unexpected error in coroutine waker"); - } - }); - } -} - -struct LoopAndFuture { - event_loop: PyObject, - future: PyObject, -} - -impl LoopAndFuture { - fn new(py: Python<'_>) -> PyResult { - static GET_RUNNING_LOOP: GILOnceCell = GILOnceCell::new(); - let import = || -> PyResult<_> { - let module = py.import("asyncio")?; - Ok(module.getattr("get_running_loop")?.into()) - }; - let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?; - let future = event_loop.call_method0(py, "create_future")?; - Ok(Self { event_loop, future }) - } - - fn set_result(&self, py: Python<'_>) -> PyResult<()> { - static RELEASE_WAITER: GILOnceCell> = GILOnceCell::new(); - let release_waiter = RELEASE_WAITER - .get_or_try_init(py, || wrap_pyfunction!(release_waiter, py).map(Into::into))?; - // `Future.set_result` must be called in event loop thread, - // so it requires `call_soon_threadsafe` - let call_soon_threadsafe = self.event_loop.call_method1( - py, - intern!(py, "call_soon_threadsafe"), - (release_waiter, self.future.as_ref(py)), - ); - if let Err(err) = call_soon_threadsafe { - // `call_soon_threadsafe` will raise if the event loop is closed; - // instead of catching an unspecific `RuntimeError`, check directly if it's closed. - let is_closed = self.event_loop.call_method0(py, "is_closed")?; - if !is_closed.extract(py)? { - return Err(err); - } - } - Ok(()) - } -} - -/// Call `future.set_result` if the future is not done. -/// -/// Future can be cancelled by the event loop before being waken. -/// See https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5 -#[pyfunction(crate = "crate")] -fn release_waiter(future: &PyAny) -> PyResult<()> { - if !future.call_method0("done")?.extract::()? { - future.call_method1("set_result", (future.py().None(),))?; - } - Ok(()) -} diff --git a/src/coroutine/cancel.rs b/src/coroutine/cancel.rs new file mode 100644 index 00000000000..fdfd24c6bc2 --- /dev/null +++ b/src/coroutine/cancel.rs @@ -0,0 +1,68 @@ +use crate::{ffi, Py, PyObject, Python}; +use futures::future::poll_fn; +use futures::task::AtomicWaker; +use std::ptr; +use std::sync::atomic::{AtomicPtr, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +#[derive(Debug, Default)] +struct Inner { + exception: AtomicPtr, + waker: AtomicWaker, +} + +/// Helper used to wait and retrieve exception thrown in coroutine. +#[derive(Debug, Default)] +pub struct CoroutineCancel(Arc); + +impl CoroutineCancel { + /// Create a new `CoroutineCancel`. + pub fn new() -> Self { + Default::default() + } + + /// Return an associated [`CancelHandle`]. + pub fn handle(&self) -> CancelHandle { + CancelHandle(self.0.clone()) + } + + fn take_exception(&self) -> PyObject { + let ptr = self.0.exception.swap(ptr::null_mut(), Ordering::Relaxed); + Python::with_gil(|gil| unsafe { Py::from_owned_ptr(gil, ptr) }) + } + + /// Returns whether the associated coroutine has been cancelled. + pub fn is_cancelled(&self) -> bool { + !self.0.exception.load(Ordering::Relaxed).is_null() + } + + /// Poll to retrieve the exception thrown in the associated coroutine. + pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll { + if self.is_cancelled() { + return Poll::Ready(self.take_exception()); + } + self.0.waker.register(cx.waker()); + if self.is_cancelled() { + return Poll::Ready(self.take_exception()); + } + Poll::Pending + } + + /// Retrieve the exception thrown in the associated coroutine. + pub async fn cancelled(&mut self) -> PyObject { + poll_fn(|cx| self.poll_cancelled(cx)).await + } +} + +/// [`CoroutineCancel`] handle used in +/// [`Coroutine::from_future_with_cancel`](crate::coroutine::Coroutine::from_future_with_cancel) +pub struct CancelHandle(Arc); + +impl CancelHandle { + pub(super) fn cancel(&self, py: Python<'_>, exc: PyObject) { + let ptr = self.0.exception.swap(exc.into_ptr(), Ordering::Relaxed); + drop(unsafe { PyObject::from_owned_ptr_or_opt(py, ptr) }); + self.0.waker.wake(); + } +} diff --git a/src/coroutine/waker.rs b/src/coroutine/waker.rs new file mode 100644 index 00000000000..0f506315d1f --- /dev/null +++ b/src/coroutine/waker.rs @@ -0,0 +1,96 @@ +use crate::sync::GILOnceCell; +use crate::types::PyCFunction; +use crate::{intern, wrap_pyfunction, Py, PyAny, PyObject, PyResult, Python}; +use futures::task::ArcWake; +use pyo3_macros::pyfunction; +use std::sync::Arc; + +/// Lazy `asyncio.Future` wrapper, implementing [`ArcWake`] by calling `Future.set_result`. +/// +/// asyncio future is let uninitialized until [`initialize_future`][1] is called. +/// If [`wake`][2] is called before future initialization (during Rust future polling), +/// [`initialize_future`][1] will return `None` (it is roughly equivalent to `asyncio.sleep(0)`) +/// +/// [1]: AsyncioWaker::initialize_future +/// [2]: AsyncioWaker::wake +pub struct AsyncioWaker(GILOnceCell>); + +impl AsyncioWaker { + pub(super) fn new() -> Self { + Self(GILOnceCell::new()) + } + + pub(super) fn reset(&mut self) { + self.0.take(); + } + + pub(super) fn initialize_future<'a>(&'a self, py: Python<'a>) -> PyResult> { + let init = || LoopAndFuture::new(py).map(Some); + let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref(); + Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.as_ref(py))) + } +} + +impl ArcWake for AsyncioWaker { + fn wake_by_ref(arc_self: &Arc) { + Python::with_gil(|gil| { + if let Some(loop_and_future) = arc_self.0.get_or_init(gil, || None) { + loop_and_future + .set_result(gil) + .expect("unexpected error in coroutine waker"); + } + }); + } +} + +struct LoopAndFuture { + event_loop: PyObject, + future: PyObject, +} + +impl LoopAndFuture { + fn new(py: Python<'_>) -> PyResult { + static GET_RUNNING_LOOP: GILOnceCell = GILOnceCell::new(); + let import = || -> PyResult<_> { + let module = py.import("asyncio")?; + Ok(module.getattr("get_running_loop")?.into()) + }; + let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?; + let future = event_loop.call_method0(py, "create_future")?; + Ok(Self { event_loop, future }) + } + + fn set_result(&self, py: Python<'_>) -> PyResult<()> { + static RELEASE_WAITER: GILOnceCell> = GILOnceCell::new(); + let release_waiter = RELEASE_WAITER + .get_or_try_init(py, || wrap_pyfunction!(release_waiter, py).map(Into::into))?; + // `Future.set_result` must be called in event loop thread, + // so it requires `call_soon_threadsafe` + let call_soon_threadsafe = self.event_loop.call_method1( + py, + intern!(py, "call_soon_threadsafe"), + (release_waiter, self.future.as_ref(py)), + ); + if let Err(err) = call_soon_threadsafe { + // `call_soon_threadsafe` will raise if the event loop is closed; + // instead of catching an unspecific `RuntimeError`, check directly if it's closed. + let is_closed = self.event_loop.call_method0(py, "is_closed")?; + if !is_closed.extract(py)? { + return Err(err); + } + } + Ok(()) + } +} + +/// Call `future.set_result` if the future is not done. +/// +/// Future can be cancelled by the event loop before being waken. +/// See https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5 +#[pyfunction(crate = "crate")] +fn release_waiter(future: &PyAny) -> PyResult<()> { + if !future.call_method0("done")?.extract::()? { + future.call_method1("set_result", (future.py().None(),))?; + } + Ok(()) +} diff --git a/src/impl_/coroutine.rs b/src/impl_/coroutine.rs index 989a97e219b..c3cbde23fe8 100644 --- a/src/impl_/coroutine.rs +++ b/src/impl_/coroutine.rs @@ -1,10 +1,10 @@ -use crate::coroutine::Coroutine; +use crate::coroutine::{CancelHandle, Coroutine}; use crate::impl_::wrap::OkWrap; use crate::{IntoPy, PyErr, PyObject, Python}; use std::future::Future; /// Used to wrap the result of async `#[pyfunction]` and `#[pymethods]`. -pub fn wrap_future(future: F) -> Coroutine +pub fn wrap_future(future: F, cancel: Option) -> Coroutine where F: Future + Send + 'static, R: OkWrap + Send, @@ -12,8 +12,13 @@ where R::Error: Send, PyErr: From, { - Coroutine::from_future(async move { + let future = async move { // SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`) future.await.wrap(unsafe { Python::assume_gil_acquired() }) - }) + }; + if let Some(cancel) = cancel { + Coroutine::from_future_with_cancel(future, cancel) + } else { + Coroutine::from_future(future) + } } diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 375608787c8..0ec85ce89a2 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -1,4 +1,6 @@ use futures::channel::oneshot; +use futures::FutureExt; +use pyo3::coroutine::CoroutineCancel; use pyo3::prelude::*; use pyo3::py_run; use std::future::poll_fn; @@ -84,3 +86,32 @@ fn cancelled_coroutine() { assert_eq!(err.value(gil).get_type().name().unwrap(), "CancelledError"); }) } + +#[test] +fn coroutine_cancel() { + #[pyfunction] + async fn cancellable_sleep(seconds: f64, mut cancel: CoroutineCancel) -> usize { + futures::select! { + _ = sleep(seconds).fuse() => 42, + _ = cancel.cancelled().fuse() => 0, + } + } + Python::with_gil(|gil| { + let cancellable_sleep = wrap_pyfunction!(cancellable_sleep, gil).unwrap(); + let test = r#" + import asyncio; + async def main(): + task = asyncio.create_task(cancellable_sleep(1)) + await asyncio.sleep(0) + task.cancel() + return await task + assert asyncio.run(main()) == 0 + "#; + let globals = gil.import("__main__").unwrap().dict(); + globals + .set_item("cancellable_sleep", cancellable_sleep) + .unwrap(); + gil.run(&pyo3::unindent::unindent(test), Some(globals), None) + .unwrap(); + }) +}