Skip to content

Commit

Permalink
Avoid rust task-impl on Python >= 3.12
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Feb 8, 2025
1 parent df62e03 commit f3f556e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 100 deletions.
5 changes: 5 additions & 0 deletions granian/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys


_PYV = int(sys.version_info.major * 100 + sys.version_info.minor)
_PY_312 = 312
7 changes: 6 additions & 1 deletion granian/server/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Type, TypeVar

from .._compat import _PY_312, _PYV
from .._imports import setproctitle, watchfiles
from .._internal import load_target
from .._signals import set_main_signals
Expand Down Expand Up @@ -468,7 +469,11 @@ def serve(
raise ConfigurationError('workers_lifetime')

if self.task_impl == TaskImpl.rust:
logger.warning('Rust task implementation is experimental!')
if _PYV >= _PY_312:
self.task_impl = TaskImpl.asyncio
logger.warning('Rust task implementation is not available on Python >= 3.12, falling back to asyncio')
else:
logger.warning('Rust task implementation is experimental!')

serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve
serve_method(spawn_target, target_loader)
110 changes: 17 additions & 93 deletions src/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#[cfg(Py_3_12)]
use std::cell::RefCell;

use pyo3::{exceptions::PyStopIteration, prelude::*, types::PyDict, IntoPyObjectExt};
use std::sync::{atomic, Arc, OnceLock, RwLock};
use tokio::sync::Notify;
Expand Down Expand Up @@ -43,61 +40,7 @@ impl CallbackScheduler {
watcher.drop_ref(py);
}

#[cfg(Py_3_12)]
#[inline]
pub(crate) fn send(pyself: Py<Self>, py: Python, state: Py<CallbackSchedulerState>) {
let rself = pyself.get();
let rstate = state.borrow(py);
let aiotask = rself.aio_task.as_ptr();

*rstate.futw.borrow_mut() = None;

unsafe {
pyo3::ffi::PyObject_CallOneArg(rself.aio_tenter.as_ptr(), aiotask);
}

if let Some(res) = unsafe {
let mut pres = std::ptr::null_mut::<pyo3::ffi::PyObject>();
// FIXME: use PyIter_Send return value once available in PyO3
pyo3::ffi::PyIter_Send(rstate.coro.as_ptr(), rself.pynone.as_ptr(), &mut pres);
Bound::from_owned_ptr_or_opt(py, pres)
.map(|v| {
if v.is_none() {
return None;
}
Some(v)
})
.unwrap()
} {
if unsafe {
let vptr = pyo3::ffi::PyObject_GetAttr(res.as_ptr(), rself.pyname_aioblock.as_ptr());
Bound::from_owned_ptr_or_err(py, vptr)
.map(|v| v.extract::<bool>().unwrap_or(false))
.unwrap_or(false)
} {
let resp = res.as_ptr();
*rstate.futw.borrow_mut() = Some(res.unbind().clone_ref(py));
drop(rstate);

unsafe {
pyo3::ffi::PyObject_SetAttr(resp, rself.pyname_aioblock.as_ptr(), rself.pyfalse.as_ptr());
CallbackSchedulerState::schedule(state, py, resp);
}
} else {
drop(rstate);
CallbackSchedulerState::reschedule(state, py);
}
} else {
drop(rstate);
state.drop_ref(py);
}

unsafe {
pyo3::ffi::PyObject_CallOneArg(rself.aio_texit.as_ptr(), aiotask);
}
}

#[cfg(all(not(Py_3_12), Py_3_10))]
#[cfg(Py_3_10)]
#[inline]
pub(crate) fn send(pyself: Py<Self>, py: Python, state: Py<CallbackSchedulerState>) {
let rself = pyself.get();
Expand All @@ -110,7 +53,7 @@ impl CallbackScheduler {
if let Some(res) = unsafe {
let mut pres = std::ptr::null_mut::<pyo3::ffi::PyObject>();
// FIXME: use PyIter_Send return value once available in PyO3
pyo3::ffi::PyIter_Send(state.borrow(py).coro.as_ptr(), rself.pynone.as_ptr(), &mut pres);
pyo3::ffi::PyIter_Send(state.get().coro.as_ptr(), rself.pynone.as_ptr(), &mut pres);
Bound::from_owned_ptr_or_opt(py, pres)
.map(|v| {
if v.is_none() {
Expand All @@ -132,7 +75,7 @@ impl CallbackScheduler {
CallbackSchedulerState::schedule(state, py, resp);
}
} else {
CallbackSchedulerState::reschedule(state, py);
CallbackSchedulerState::reschedule(state);
}
} else {
state.drop_ref(py);
Expand All @@ -155,7 +98,7 @@ impl CallbackScheduler {

if let Ok(res) = unsafe {
let pres = pyo3::ffi::PyObject_CallMethodOneArg(
state.borrow(py).coro.as_ptr(),
state.get().coro.as_ptr(),
rself.pyname_aiosend.as_ptr(),
rself.pynone.as_ptr(),
);
Expand All @@ -173,7 +116,7 @@ impl CallbackScheduler {
CallbackSchedulerState::schedule(state, py, resp);
}
} else {
CallbackSchedulerState::reschedule(state, py);
CallbackSchedulerState::reschedule(state);
}
} else {
state.drop_ref(py);
Expand All @@ -192,7 +135,7 @@ impl CallbackScheduler {
unsafe {
pyo3::ffi::PyObject_CallOneArg(rself.aio_tenter.as_ptr(), aiotask);
pyo3::ffi::PyObject_CallMethodOneArg(
state.borrow(py).coro.as_ptr(),
state.get().coro.as_ptr(),
rself.pyname_aiothrow.as_ptr(),
err.into_ptr(),
);
Expand Down Expand Up @@ -227,7 +170,7 @@ impl CallbackScheduler {

if let Ok(res) = unsafe {
let res = pyo3::ffi::PyObject_CallMethodObjArgs(
state.borrow(py).coro.as_ptr(),
state.get().coro.as_ptr(),
rself.pyname_aiosend.as_ptr(),
rself.pynone.as_ptr(),
std::ptr::null_mut::<PyObject>(),
Expand All @@ -247,7 +190,7 @@ impl CallbackScheduler {
CallbackSchedulerState::schedule(state, py, resp);
}
} else {
CallbackSchedulerState::reschedule(state, py);
CallbackSchedulerState::reschedule(state);
}
} else {
state.drop_ref(py);
Expand All @@ -266,7 +209,7 @@ impl CallbackScheduler {
unsafe {
pyo3::ffi::PyObject_CallObject(rself.aio_tenter.as_ptr(), aiotask);
pyo3::ffi::PyObject_CallMethodObjArgs(
state.borrow(py).coro.as_ptr(),
state.get().coro.as_ptr(),
rself.pyname_aiothrow.as_ptr(),
(err,).into_py_any(py).unwrap().into_ptr(),
std::ptr::null_mut::<PyObject>(),
Expand Down Expand Up @@ -341,15 +284,13 @@ impl CallbackScheduler {
}
}

#[pyclass(frozen, freelist = 1024, unsendable, module = "granian._granian")]
#[pyclass(frozen, freelist = 1024, module = "granian._granian")]
pub(crate) struct CallbackSchedulerState {
sched: Py<CallbackScheduler>,
coro: PyObject,
ctxd: Py<PyDict>,
pys_futcb: PyObject,
pym_schedule: PyObject,
#[cfg(Py_3_12)]
futw: RefCell<Option<PyObject>>,
}

impl CallbackSchedulerState {
Expand All @@ -371,24 +312,25 @@ impl CallbackSchedulerState {
ctxd: ctxd.unbind(),
pys_futcb: pyo3::intern!(py, "add_done_callback").into_py_any(py).unwrap(),
pym_schedule,
#[cfg(Py_3_12)]
futw: RefCell::new(None),
},
)
.unwrap()
}

unsafe fn schedule(pyself: Py<Self>, py: Python, step: *mut pyo3::ffi::PyObject) {
let rself = pyself.borrow(py);
let rself = pyself.get();
pyo3::ffi::PyObject_Call(
pyo3::ffi::PyObject_GetAttr(step, rself.pys_futcb.as_ptr()),
pyself.getattr(py, pyo3::intern!(py, "wake")).unwrap().as_ptr(),
(pyself.getattr(py, pyo3::intern!(py, "wake")).unwrap(),)
.into_py_any(py)
.unwrap()
.as_ptr(),
rself.ctxd.as_ptr(),
);
}

fn reschedule(pyself: Py<Self>, py: Python) {
let rself = pyself.borrow(py);
fn reschedule(pyself: Py<Self>) {
let rself = pyself.get();
unsafe {
pyo3::ffi::PyObject_Call(rself.pym_schedule.as_ptr(), pyself.as_ptr(), rself.ctxd.as_ptr());
}
Expand All @@ -409,24 +351,6 @@ impl CallbackSchedulerState {
Err(err) => CallbackScheduler::throw(sched, py, pyself, err.into_py_any(py).unwrap()),
}
}

#[cfg(Py_3_12)]
fn cancel(&self, py: Python) -> PyResult<PyObject> {
if let Some(v) = self.futw.borrow().as_ref() {
return v.call_method0(py, pyo3::intern!(py, "cancel"));
}
Ok(self.sched.get().pyfalse.clone_ref(py))
}

#[cfg(Py_3_12)]
fn cancelling(&self) -> i32 {
0
}

#[cfg(Py_3_12)]
fn uncancel(&self) -> i32 {
0
}
}

#[pyclass(frozen, module = "granian._granian")]
Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ def _serve(**kwargs):


@asynccontextmanager
async def _server(interface, port, threading_mode, tls=False):
async def _server(interface, port, threading_mode, tls=False, task_impl='asyncio'):
certs_path = Path.cwd() / 'tests' / 'fixtures' / 'tls'
kwargs = {
'interface': interface,
'port': port,
'blocking_threads': 1,
'threading_mode': threading_mode,
'task_impl': task_impl,
}
if tls:
if tls == 'private':
Expand Down Expand Up @@ -74,8 +75,8 @@ def server_port():


@pytest.fixture(scope='function')
def asgi_server(server_port):
return partial(_server, 'asgi', server_port)
def asgi_server(server_port, **extras):
return partial(_server, 'asgi', server_port, **extras)


@pytest.fixture(scope='function')
Expand Down
6 changes: 3 additions & 3 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async def test_file(asgi_server, threading_mode):
@pytest.mark.skipif(bool(os.getenv('PGO_RUN')), reason='PGO build')
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_sniffio(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port:
async with asgi_server(threading_mode, task_impl='rust') as port:
res = httpx.get(f'http://localhost:{port}/sniffio')

assert res.status_code == 200
Expand All @@ -106,13 +106,13 @@ async def test_sniffio(asgi_server, threading_mode):
@pytest.mark.skipif(bool(os.getenv('PGO_RUN')), reason='PGO build')
@pytest.mark.parametrize('threading_mode', ['runtime', 'workers'])
async def test_timeout(asgi_server, threading_mode):
async with asgi_server(threading_mode) as port:
async with asgi_server(threading_mode, task_impl='rust') as port:
res = httpx.get(f'http://localhost:{port}/timeout_n')

assert res.status_code == 200
assert res.text == 'ok'

async with asgi_server(threading_mode) as port:
async with asgi_server(threading_mode, task_impl='rust') as port:
res = httpx.get(f'http://localhost:{port}/timeout_w')

assert res.status_code == 200
Expand Down

0 comments on commit f3f556e

Please sign in to comment.