Skip to content

Commit

Permalink
feat: add #[pyo3(allow_threads)] to release the GIL in (async) func…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
wyfo committed Dec 7, 2023
1 parent f34c70c commit 3bc7d67
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 16 deletions.
1 change: 1 addition & 0 deletions newsfragments/3610.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `#[pyo3(allow_threads)]` to release the GIL in (async) functions
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use syn::{
};

pub mod kw {
syn::custom_keyword!(allow_threads);
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
Expand Down
27 changes: 26 additions & 1 deletion pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::fmt::Display;

use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::{
attributes,
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::impl_arg_params,
Expand Down Expand Up @@ -241,6 +242,7 @@ pub struct FnSpec<'a> {
pub asyncness: Option<syn::Token![async]>,
pub unsafety: Option<syn::Token![unsafe]>,
pub deprecations: Deprecations,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

pub fn get_return_info(output: &syn::ReturnType) -> syn::Type {
Expand Down Expand Up @@ -284,6 +286,7 @@ impl<'a> FnSpec<'a> {
text_signature,
name,
signature,
allow_threads,
..
} = options;

Expand Down Expand Up @@ -331,6 +334,7 @@ impl<'a> FnSpec<'a> {
asyncness: sig.asyncness,
unsafety: sig.unsafety,
deprecations,
allow_threads,
})
}

Expand Down Expand Up @@ -474,6 +478,7 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>| {
let allow_threads = self.allow_threads.is_some();
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
Expand Down Expand Up @@ -502,6 +507,7 @@ impl<'a> FnSpec<'a> {
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
#throw_callback,
#allow_threads,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
Expand All @@ -513,6 +519,25 @@ impl<'a> FnSpec<'a> {
}};
}
call
} else if allow_threads {
let (self_arg_name, self_arg_decl) = if self_arg.is_empty() {
(quote!(), quote!())
} else {
(quote!(__self), quote! { let __self = #self_arg; })
};
let arg_names: Vec<Ident> = (0..args.len())
.map(|i| format_ident!("__arg{}", i))
.collect();
let arg_decls: Vec<TokenStream> = args
.into_iter()
.zip(&arg_names)
.map(|(arg, name)| quote! { let #name = #arg; })
.collect();
quote! {{
#self_arg_decl
#(#arg_decls)*
py.allow_threads(|| function(#self_arg_name #(#arg_names),*))
}}
} else {
quote! { function(#self_arg #(#args),*) }
};
Expand Down
12 changes: 10 additions & 2 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub struct PyFunctionOptions {
pub signature: Option<SignatureAttribute>,
pub text_signature: Option<TextSignatureAttribute>,
pub krate: Option<CrateAttribute>,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

impl Parse for PyFunctionOptions {
Expand All @@ -99,7 +100,8 @@ impl Parse for PyFunctionOptions {

while !input.is_empty() {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name)
if lookahead.peek(attributes::kw::allow_threads)
|| lookahead.peek(attributes::kw::name)
|| lookahead.peek(attributes::kw::pass_module)
|| lookahead.peek(attributes::kw::signature)
|| lookahead.peek(attributes::kw::text_signature)
Expand All @@ -121,6 +123,7 @@ impl Parse for PyFunctionOptions {
}

pub enum PyFunctionOption {
AllowThreads(attributes::kw::allow_threads),
Name(NameAttribute),
PassModule(attributes::kw::pass_module),
Signature(SignatureAttribute),
Expand All @@ -131,7 +134,9 @@ pub enum PyFunctionOption {
impl Parse for PyFunctionOption {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
if lookahead.peek(attributes::kw::allow_threads) {
input.parse().map(PyFunctionOption::AllowThreads)
} else if lookahead.peek(attributes::kw::name) {
input.parse().map(PyFunctionOption::Name)
} else if lookahead.peek(attributes::kw::pass_module) {
input.parse().map(PyFunctionOption::PassModule)
Expand Down Expand Up @@ -171,6 +176,7 @@ impl PyFunctionOptions {
}
for attr in attrs {
match attr {
PyFunctionOption::AllowThreads(allow_threads) => set_option!(allow_threads),
PyFunctionOption::Name(name) => set_option!(name),
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
PyFunctionOption::Signature(signature) => set_option!(signature),
Expand Down Expand Up @@ -198,6 +204,7 @@ pub fn impl_wrap_pyfunction(
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
let PyFunctionOptions {
allow_threads,
pass_module,
name,
signature,
Expand Down Expand Up @@ -247,6 +254,7 @@ pub fn impl_wrap_pyfunction(
signature,
output: ty,
text_signature,
allow_threads,
asyncness: func.sig.asyncness,
unsafety: func.sig.unsafety,
deprecations: Deprecations::new(),
Expand Down
1 change: 1 addition & 0 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
/// | `#[pyo3(allow_threads)]` | Release the GIL in the function body, or each time the returned future is polled for `async fn` |
///
/// For more on exposing functions see the [function section of the guide][1].
///
Expand Down
11 changes: 10 additions & 1 deletion src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}
Expand All @@ -47,6 +48,7 @@ impl Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Self
where
Expand All @@ -63,6 +65,7 @@ impl Coroutine {
name,
qualname_prefix,
throw_callback,
allow_threads,
future: Some(Box::pin(wrap)),
waker: None,
}
Expand Down Expand Up @@ -96,7 +99,13 @@ impl Coroutine {
let waker = Waker::from(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
// polling is UnwindSafe because the future is dropped in case of panic
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
let poll = || {
if self.allow_threads {
py.allow_threads(|| future_rs.as_mut().poll(&mut Context::from_waker(&waker)))
} else {
future_rs.as_mut().poll(&mut Context::from_waker(&waker))
}
};
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
Expand Down
70 changes: 60 additions & 10 deletions src/gil.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
//! Interaction with Python's global interpreter lock
use crate::impl_::not_send::{NotSend, NOT_SEND};
use crate::{ffi, Python};
use parking_lot::{const_mutex, Mutex, Once};
use std::cell::Cell;
#[cfg(debug_assertions)]
use std::cell::RefCell;
#[cfg(not(debug_assertions))]
use std::cell::UnsafeCell;
use std::{mem, ptr::NonNull};
use std::{cell::Cell, mem, ptr::NonNull};

use parking_lot::{const_mutex, Mutex, Once};

use crate::{
ffi,
impl_::not_send::{NotSend, NOT_SEND},
Python,
};

static START: Once = Once::new();

Expand Down Expand Up @@ -506,11 +510,13 @@ fn decrement_gil_count() {

#[cfg(test)]
mod tests {
use super::{gil_is_acquired, GILPool, GIL_COUNT, OWNED_OBJECTS, POOL};
use crate::{ffi, gil, PyObject, Python, ToPyObject};
use std::ptr::NonNull;

#[cfg(not(target_arch = "wasm32"))]
use parking_lot::{const_mutex, Condvar, Mutex};
use std::ptr::NonNull;

use super::{gil_is_acquired, GILPool, GIL_COUNT, OWNED_OBJECTS, POOL};
use crate::{ffi, gil, PyObject, Python, ToPyObject};

fn get_object(py: Python<'_>) -> PyObject {
// Convenience function for getting a single unique object, using `new_pool` so as to leave
Expand Down Expand Up @@ -786,9 +792,10 @@ mod tests {
#[test]
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
fn test_clone_without_gil() {
use crate::{Py, PyAny};
use std::{sync::Arc, thread};

use crate::{Py, PyAny};

// Some events for synchronizing
static GIL_ACQUIRED: Event = Event::new();
static OBJECT_CLONED: Event = Event::new();
Expand Down Expand Up @@ -851,9 +858,10 @@ mod tests {
#[test]
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
fn test_clone_in_other_thread() {
use crate::Py;
use std::{sync::Arc, thread};

use crate::Py;

// Some events for synchronizing
static OBJECT_CLONED: Event = Event::new();

Expand Down Expand Up @@ -925,4 +933,46 @@ mod tests {
POOL.update_counts(py);
})
}

#[cfg(feature = "macros")]
#[test]
fn allow_threads_fn() {
#[crate::pyfunction(allow_threads, crate = "crate")]
fn without_gil() {
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
}
Python::with_gil(|gil| {
let without_gil = crate::wrap_pyfunction!(without_gil, gil).unwrap();
crate::py_run!(gil, without_gil, "without_gil()");
})
}

#[cfg(feature = "macros")]
#[test]
fn allow_threads_async_fn() {
#[crate::pyfunction(allow_threads, crate = "crate")]
async fn without_gil() {
use std::task::Poll;
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
let mut ready = false;
futures::future::poll_fn(|cx| {
if ready {
return Poll::Ready(());
}
ready = true;
cx.waker().wake_by_ref();
Poll::Pending
})
.await;
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
}
Python::with_gil(|gil| {
let without_gil = crate::wrap_pyfunction!(without_gil, gil).unwrap();
crate::py_run!(
gil,
without_gil,
"import asyncio; asyncio.run(without_gil())"
);
})
}
}
9 changes: 8 additions & 1 deletion src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@ pub fn new_coroutine<F, T, E>(
name: &PyString,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Coroutine
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
E: Into<PyErr>,
{
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
Coroutine::new(
Some(name.into()),
qualname_prefix,
throw_callback,
allow_threads,
future,
)
}

fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/invalid_pyfunction_signatures.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ error: expected argument from function definition `y` but got argument `x`
13 | #[pyo3(signature = (x))]
| ^

error: expected one of: `name`, `pass_module`, `signature`, `text_signature`, `crate`
error: expected one of: `allow_threads`, `name`, `pass_module`, `signature`, `text_signature`, `crate`
--> tests/ui/invalid_pyfunction_signatures.rs:18:14
|
18 | #[pyfunction(x)]
Expand Down

0 comments on commit 3bc7d67

Please sign in to comment.