Skip to content

Commit

Permalink
feat: handle coroutine cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Perez committed Oct 29, 2023
1 parent 9c521b1 commit a078c69
Show file tree
Hide file tree
Showing 11 changed files with 300 additions and 108 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ unindent = { version = "0.2.1", optional = true }
inventory = { version = "0.3.0", optional = true }

# coroutine implementation
futures-task = "0.3"
futures= "0.3"

# crate integrations that can be added using the eponymous features
anyhow = { version = "1.0", optional = true }
Expand All @@ -58,7 +58,6 @@ serde_json = "1.0.61"
rayon = "1.0.2"
rust_decimal = { version = "1.8.0", features = ["std"] }
widestring = "0.5.1"
futures = "0.3.28"

[build-dependencies]
pyo3-build-config = { path = "pyo3-build-config", version = "0.20.0", features = ["resolve-config"] }
Expand Down
21 changes: 20 additions & 1 deletion pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct FnArg<'a> {
pub optional: Option<&'a syn::Type>,
pub default: Option<syn::Expr>,
pub py: bool,
pub coroutine_cancel: bool,
pub attrs: PyFunctionArgPyO3Attributes,
pub is_varargs: bool,
pub is_kwargs: bool,
Expand Down Expand Up @@ -50,6 +51,7 @@ impl<'a> FnArg<'a> {
optional: utils::option_type_argument(&cap.ty),
default: None,
py: utils::is_python(&cap.ty),
coroutine_cancel: utils::is_coroutine_cancel(&cap.ty),
attrs: arg_attrs,
is_varargs: false,
is_kwargs: false,
Expand Down Expand Up @@ -446,10 +448,27 @@ impl<'a> FnSpec<'a> {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
let func_name = &self.name;

let coroutine_cancel = self
.signature
.arguments
.iter()
.find(|arg| arg.coroutine_cancel);
if let (None, Some(arg)) = (&self.asyncness, coroutine_cancel) {
bail_spanned!(arg.ty.span() => "`CoroutineCancel` argument only allowed with `async fn`");
}

let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
call = quote! { _pyo3::impl_::coroutine::wrap_future(#call) };
call = if coroutine_cancel.is_some() {
quote! {{
let __coroutine_cancel = _pyo3::coroutine::CoroutineCancel::new();
let __cancel_handle = __coroutine_cancel.handle();
_pyo3::impl_::coroutine::wrap_future({ #call }, Some(__cancel_handle))
}}
} else {
quote! { _pyo3::impl_::coroutine::wrap_future(#call, None) }
};
}
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};
Expand Down
4 changes: 4 additions & 0 deletions pyo3-macros-backend/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ fn impl_arg_param(
return Ok(quote! { py });
}

if arg.coroutine_cancel {
return Ok(quote! { __coroutine_cancel });
}

let name = arg.name;
let name_str = name.to_string();

Expand Down
14 changes: 12 additions & 2 deletions pyo3-macros-backend/src/pyfunction/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> {
// Otherwise try next argument.
continue;
}
if fn_arg.coroutine_cancel {
// If the user incorrectly tried to include cancel: CoroutineCancel in the
// signature, give a useful error as a hint.
ensure_spanned!(
name != fn_arg.name,
name.span() => "arguments of type `CoroutineCancel` must not be part of the signature"
);
// Otherwise try next argument.
continue;
}

ensure_spanned!(
name == fn_arg.name,
Expand Down Expand Up @@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> {
}

// Ensure no non-py arguments remain
if let Some(arg) = args_iter.find(|arg| !arg.py) {
if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.coroutine_cancel) {
bail_spanned!(
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
);
Expand All @@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> {
let mut python_signature = PythonSignature::default();
for arg in &arguments {
// Python<'_> arguments don't show in Python signature
if arg.py {
if arg.py || arg.coroutine_cancel {
continue;
}

Expand Down
13 changes: 13 additions & 0 deletions pyo3-macros-backend/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ pub fn is_python(ty: &syn::Type) -> bool {
}
}

/// Check if the given type `ty` is `pyo3::coroutine::CoroutineCancel`.
pub fn is_coroutine_cancel(ty: &syn::Type) -> bool {
match unwrap_ty_group(ty) {
syn::Type::Path(typath) => typath
.path
.segments
.last()
.map(|seg| seg.ident == "CoroutineCancel")
.unwrap_or(false),
_ => false,
}
}

/// If `ty` is `Option<T>`, return `Some(T)`, else `None`.
pub fn option_type_argument(ty: &syn::Type) -> Option<&syn::Type> {
if let syn::Type::Path(syn::TypePath { path, .. }) = ty {
Expand Down
1 change: 1 addition & 0 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
/// /// | `#[pyo3(coroutine_cancel = "...")]` | (`async fn` only) Pass a `CoroutineCancel` instance to the given parameter |
///
/// For more on exposing functions see the [function section of the guide][1].
///
Expand Down
144 changes: 45 additions & 99 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
//! Python coroutine implementation, used notably when wrapping `async fn`
//! with `#[pyfunction]`/`#[pymethods]`.
use crate::coroutine::waker::AsyncioWaker;
use crate::exceptions::{PyRuntimeError, PyStopIteration};
use crate::pyclass::IterNextOutput;
use crate::sync::GILOnceCell;
use crate::types::{PyCFunction, PyIterator};
use crate::{intern, wrap_pyfunction, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python};
use pyo3_macros::{pyclass, pyfunction, pymethods};
use crate::types::PyIterator;
use crate::{IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python};
use pyo3_macros::{pyclass, pymethods};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

mod cancel;
mod waker;

pub use crate::coroutine::cancel::{CancelHandle, CoroutineCancel};

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
cancel: Option<CancelHandle>,
waker: Option<Arc<AsyncioWaker>>,
}

Expand All @@ -41,14 +47,40 @@ impl Coroutine {
};
Self {
future: Some(Box::pin(wrap)),
cancel: None,
waker: None,
}
}

/// Wrap a future into a Python coroutine.
///
/// Coroutine `send` polls the wrapped future, ignoring the value passed
/// (should always be `None` anyway).
///
/// Coroutine `throw` registers the exception in `cancel`, and polls the wrapped future
pub fn from_future_with_cancel<F, T, E>(future: F, cancel: CancelHandle) -> Self
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject> + Send,
E: Send,
PyErr: From<E>,
{
let wrap = async move {
let obj = future.await?;
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
};
Self {
future: Some(Box::pin(wrap)),
cancel: Some(cancel),
waker: None,
}
}

fn poll(
&mut self,
py: Python<'_>,
throw: Option<&PyAny>,
throw: Option<PyObject>,
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
// raise if the coroutine has already been run to completion
let future_rs = match self.future {
Expand All @@ -57,16 +89,20 @@ impl Coroutine {
};
// reraise thrown exception it
if let Some(exc) = throw {
self.close();
return Err(PyErr::from_value(exc));
if let Some(ref handle) = self.cancel {
handle.cancel(py, exc)
} else {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
}
}
// create a new waker, or try to reset it in place
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
waker.reset();
} else {
self.waker = Some(Arc::new(AsyncioWaker::new()));
}
let waker = futures_task::waker(self.waker.clone().unwrap());
let waker = futures::task::waker(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
if let Poll::Ready(res) = future_rs.as_mut().poll(&mut Context::from_waker(&waker)) {
self.close();
Expand Down Expand Up @@ -101,7 +137,7 @@ impl Coroutine {
iter_result(self.poll(py, None)?)
}

fn throw(&mut self, py: Python<'_>, exc: &PyAny) -> PyResult<PyObject> {
fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
iter_result(self.poll(py, Some(exc))?)
}

Expand All @@ -119,93 +155,3 @@ impl Coroutine {
self.poll(py, None)
}
}

/// Lazy `asyncio.Future` wrapper, implementing [`ArcWake`] by calling `Future.set_result`.
///
/// asyncio future is let uninitialized until [`initialize_future`][1] is called.
/// If [`wake`][2] is called before future initialization (during Rust future polling),
/// [`initialize_future`][1] will return `None` (it is roughly equivalent to `asyncio.sleep(0)`)
///
/// [1]: AsyncioWaker::initialize_future
/// [2]: AsyncioWaker::wake
struct AsyncioWaker(GILOnceCell<Option<LoopAndFuture>>);

impl AsyncioWaker {
fn new() -> Self {
Self(GILOnceCell::new())
}

fn reset(&mut self) {
self.0.take();
}

fn initialize_future<'a>(&'a self, py: Python<'a>) -> PyResult<Option<&'a PyAny>> {
let init = || LoopAndFuture::new(py).map(Some);
let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref();
Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.as_ref(py)))
}
}

impl futures_task::ArcWake for AsyncioWaker {
fn wake_by_ref(arc_self: &Arc<Self>) {
Python::with_gil(|gil| {
if let Some(loop_and_future) = arc_self.0.get_or_init(gil, || None) {
loop_and_future
.set_result(gil)
.expect("unexpected error in coroutine waker");
}
});
}
}

struct LoopAndFuture {
event_loop: PyObject,
future: PyObject,
}

impl LoopAndFuture {
fn new(py: Python<'_>) -> PyResult<Self> {
static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
let import = || -> PyResult<_> {
let module = py.import("asyncio")?;
Ok(module.getattr("get_running_loop")?.into())
};
let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?;
let future = event_loop.call_method0(py, "create_future")?;
Ok(Self { event_loop, future })
}

fn set_result(&self, py: Python<'_>) -> PyResult<()> {
static RELEASE_WAITER: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new();
let release_waiter = RELEASE_WAITER
.get_or_try_init(py, || wrap_pyfunction!(release_waiter, py).map(Into::into))?;
// `Future.set_result` must be called in event loop thread,
// so it requires `call_soon_threadsafe`
let call_soon_threadsafe = self.event_loop.call_method1(
py,
intern!(py, "call_soon_threadsafe"),
(release_waiter, self.future.as_ref(py)),
);
if let Err(err) = call_soon_threadsafe {
// `call_soon_threadsafe` will raise if the event loop is closed;
// instead of catching an unspecific `RuntimeError`, check directly if it's closed.
let is_closed = self.event_loop.call_method0(py, "is_closed")?;
if !is_closed.extract(py)? {
return Err(err);
}
}
Ok(())
}
}

/// Call `future.set_result` if the future is not done.
///
/// Future can be cancelled by the event loop before being waken.
/// See https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5
#[pyfunction(crate = "crate")]
fn release_waiter(future: &PyAny) -> PyResult<()> {
if !future.call_method0("done")?.extract::<bool>()? {
future.call_method1("set_result", (future.py().None(),))?;
}
Ok(())
}
Loading

0 comments on commit a078c69

Please sign in to comment.