Skip to content

Commit

Permalink
revise and document cancellation semantics
Browse files Browse the repository at this point in the history
in short, cancellable threads always use system tasks. normal threads use the host task, unless passed a token
  • Loading branch information
richardsheridan committed Oct 15, 2023
1 parent eab30c4 commit 5d93ed9
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 101 deletions.
13 changes: 12 additions & 1 deletion docs/source/reference-core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1823,9 +1823,20 @@ to spawn a child thread, and then use a :ref:`memory channel

.. literalinclude:: reference-core/from-thread-example.py

.. note::

The ``from_thread.run*`` functions reuse the host task that called
:func:`trio.to_thread.run_sync` to run your provided function in the typical case,
namely when ``cancellable=False`` so Trio can be sure that the task will always be
around to perform the work. If you pass ``cancellable=True`` at the outset, or if
you provide a :class:`~trio.lowlevel.TrioToken` when calling back in to Trio, your
functions will be executed in a new system task. Therefore, the
:func:`~trio.lowlevel.current_task`, :func:`current_effective_deadline`, or other
task-tree specific values may differ depending on keyword argument values.

You can also use :func:`trio.from_thread.check_cancelled` to check for cancellation from
a thread that was spawned by :func:`trio.to_thread.run_sync`. If the call to
:func:`~trio.to_thread.run_sync` was cancelled, then
:func:`~trio.to_thread.run_sync` was cancelled (even if ``cancellable=False``!), then
:func:`~trio.from_thread.check_cancelled` will raise :func:`trio.Cancelled`.
It's like ``trio.from_thread.run(trio.sleep, 0)``, but much faster.

Expand Down
34 changes: 19 additions & 15 deletions trio/_tests/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,14 +933,16 @@ async def async_time_bomb():
async def test_from_thread_check_cancelled():
q = stdlib_queue.Queue()

async def child(cancellable):
record.append("start")
try:
return await to_thread_run_sync(f, cancellable=cancellable)
except _core.Cancelled:
record.append("cancel")
finally:
record.append("exit")
async def child(cancellable, scope):
with scope:
record.append("start")
try:
return await to_thread_run_sync(f, cancellable=cancellable)
except _core.Cancelled:
record.append("cancel")
raise
finally:
record.append("exit")

def f():
try:
Expand All @@ -956,7 +958,7 @@ def f():
record = []
ev = threading.Event()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, False)
nursery.start_soon(child, False, _core.CancelScope())
await wait_all_tasks_blocked()
assert record[0] == "start"
assert q.get(timeout=1) == "Not Cancelled"
Expand All @@ -968,14 +970,15 @@ def f():
# the appropriate cancel scope
record = []
ev = threading.Event()
scope = _core.CancelScope() # Nursery cancel scope gives false positives
async with _core.open_nursery() as nursery:
nursery.start_soon(child, False)
nursery.start_soon(child, False, scope)
await wait_all_tasks_blocked()
assert record[0] == "start"
assert q.get(timeout=1) == "Not Cancelled"
nursery.cancel_scope.cancel()
scope.cancel()
ev.set()
assert nursery.cancel_scope.cancelled_caught
assert scope.cancelled_caught
assert "cancel" in record
assert record[-1] == "exit"

Expand All @@ -992,13 +995,14 @@ def f(): # noqa: F811

record = []
ev = threading.Event()
scope = _core.CancelScope()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, True)
nursery.start_soon(child, True, scope)
await wait_all_tasks_blocked()
assert record[0] == "start"
nursery.cancel_scope.cancel()
scope.cancel()
ev.set()
assert nursery.cancel_scope.cancelled_caught
assert scope.cancelled_caught
assert "cancel" in record
assert record[-1] == "exit"
assert q.get(timeout=1) == "Cancelled"
Expand Down
163 changes: 78 additions & 85 deletions trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from trio._core._traps import RaiseCancelT

from ._core import (
CancelScope,
RunVar,
TrioToken,
disable_ki_protection,
Expand All @@ -35,6 +34,7 @@ class _ParentTaskData(threading.local):
parent task of native Trio threads."""

token: TrioToken
abandon_on_cancel: bool
cancel_register: list[RaiseCancelT | None]
task_register: list[trio.lowlevel.Task | None]

Expand Down Expand Up @@ -74,11 +74,6 @@ class ThreadPlaceholder:


# Types for the to_thread_run_sync message loop
@attr.s(frozen=True, eq=False)
class ThreadDone(Generic[RetT]):
result: outcome.Outcome[RetT] = attr.ib()


@attr.s(frozen=True, eq=False)
class Run(Generic[RetT]):
afn: Callable[..., Awaitable[RetT]] = attr.ib()
Expand All @@ -87,7 +82,6 @@ class Run(Generic[RetT]):
queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib(
init=False, factory=stdlib_queue.SimpleQueue
)
scope: CancelScope = attr.ib(init=False, factory=CancelScope)

@disable_ki_protection
async def unprotected_afn(self) -> RetT:
Expand All @@ -108,14 +102,32 @@ async def run(self) -> None:
await trio.lowlevel.cancel_shielded_checkpoint()

async def run_system(self) -> None:
# NOTE: There is potential here to only conditionally enter a CancelScope
# when we need it, sparing some computation. But doing so adds substantial
# complexity, so we'll leave it until real need is demonstrated.
with self.scope:
result = await outcome.acapture(self.unprotected_afn)
assert not self.scope.cancelled_caught, "any Cancelled should go to our parent"
result = await outcome.acapture(self.unprotected_afn)
self.queue.put_nowait(result)

Check warning on line 106 in trio/_threads.py

View check run for this annotation

Codecov / codecov/patch

trio/_threads.py#L105-L106

Added lines #L105 - L106 were not covered by tests

def run_in_host_task(self, token: TrioToken) -> None:
task_register = PARENT_TASK_DATA.task_register

def in_trio_thread() -> None:
task = task_register[0]
assert task is not None, "guaranteed by abandon_on_cancel semantics"
trio.lowlevel.reschedule(task, outcome.Value(self))

token.run_sync_soon(in_trio_thread)

def run_in_system_nursery(self, token: TrioToken) -> None:
def in_trio_thread() -> None:
try:
trio.lowlevel.spawn_system_task(
self.run, name=self.afn, context=self.context
)
except RuntimeError: # system nursery is closed
self.queue.put_nowait(
outcome.Error(trio.RunFinishedError("system nursery is closed"))
)

token.run_sync_soon(in_trio_thread)


@attr.s(frozen=True, eq=False)
class RunSync(Generic[RetT]):
Expand Down Expand Up @@ -144,6 +156,19 @@ def run_sync(self) -> None:
result = outcome.capture(self.context.run, self.unprotected_fn)
self.queue.put_nowait(result)

def run_in_host_task(self, token: TrioToken) -> None:
task_register = PARENT_TASK_DATA.task_register

def in_trio_thread() -> None:
task = task_register[0]
assert task is not None, "guaranteed by abandon_on_cancel semantics"
trio.lowlevel.reschedule(task, outcome.Value(self))

token.run_sync_soon(in_trio_thread)

def run_in_system_nursery(self, token: TrioToken) -> None:
token.run_sync_soon(self.run_sync)


@enable_ki_protection # Decorator used on function with Coroutine[Any, Any, RetT]
async def to_thread_run_sync( # type: ignore[misc]
Expand Down Expand Up @@ -237,7 +262,7 @@ async def to_thread_run_sync( # type: ignore[misc]
"""
await trio.lowlevel.checkpoint_if_cancelled()
cancellable = bool(cancellable) # raise early if cancellable.__bool__ raises
abandon_on_cancel = bool(cancellable) # raise early if cancellable.__bool__ raises
if limiter is None:
limiter = current_default_thread_limiter()

Expand Down Expand Up @@ -266,9 +291,7 @@ def do_release_then_return_result() -> RetT:

result = outcome.capture(do_release_then_return_result)
if task_register[0] is not None:
trio.lowlevel.reschedule(
task_register[0], outcome.Value(ThreadDone(result))
)
trio.lowlevel.reschedule(task_register[0], outcome.Value(result))

current_trio_token = trio.lowlevel.current_trio_token()

Expand All @@ -283,6 +306,7 @@ def worker_fn() -> RetT:
current_async_library_cvar.set(None)

PARENT_TASK_DATA.token = current_trio_token
PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel
PARENT_TASK_DATA.cancel_register = cancel_register
PARENT_TASK_DATA.task_register = task_register
try:
Expand All @@ -299,6 +323,7 @@ def worker_fn() -> RetT:
return ret
finally:
del PARENT_TASK_DATA.token
del PARENT_TASK_DATA.abandon_on_cancel
del PARENT_TASK_DATA.cancel_register
del PARENT_TASK_DATA.task_register

Expand Down Expand Up @@ -336,11 +361,11 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:

while True:
# wait_task_rescheduled return value cannot be typed
msg_from_thread: ThreadDone[RetT] | Run[object] | RunSync[
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[
object
] = await trio.lowlevel.wait_task_rescheduled(abort)
if isinstance(msg_from_thread, ThreadDone):
return msg_from_thread.result.unwrap() # type: ignore[no-any-return]
if isinstance(msg_from_thread, outcome.Outcome):
return msg_from_thread.unwrap() # type: ignore[no-any-return]
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
Expand All @@ -354,10 +379,10 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:


def from_thread_check_cancelled() -> None:
"""Raise trio.Cancelled if the associated Trio task entered a cancelled status.
"""Raise `trio.Cancelled` if the associated Trio task entered a cancelled status.
Only applicable to threads spawned by `trio.to_thread.run_sync`. Poll to allow
``cancellable=False`` threads to raise :exc:`trio.Cancelled` at a suitable
``cancellable=False`` threads to raise :exc:`~trio.Cancelled` at a suitable
place, or to end abandoned ``cancellable=True`` threads sooner than they may
otherwise.
Expand All @@ -366,6 +391,13 @@ def from_thread_check_cancelled() -> None:
delivery of cancellation attempted against it, regardless of the value of
``cancellable`` supplied as an argument to it.
RuntimeError: If this thread is not spawned from `trio.to_thread.run_sync`.
.. note::
The check for cancellation attempts of ``cancellable=False`` threads is
interrupted while executing ``from_thread.run*`` functions, which can lead to
edge cases where this function may raise or not depending on the timing of
:class:`~trio.CancelScope` shields being raised or lowered in the Trio threads.
"""
try:
raise_cancel = PARENT_TASK_DATA.cancel_register[0]
Expand Down Expand Up @@ -406,49 +438,6 @@ def _check_token(trio_token: TrioToken | None) -> TrioToken:
return trio_token


def _send_message_to_host_task(
message: Run[RetT] | RunSync[RetT], trio_token: TrioToken
) -> None:
task_register = PARENT_TASK_DATA.task_register

def in_trio_thread() -> None:
task = task_register[0]
if task is None:
# Our parent task is gone! Punt to a system task.
if isinstance(message, Run):
message.scope.cancel()
_send_message_to_system_task(message, trio_token)
else:
trio.lowlevel.reschedule(task, outcome.Value(message))

trio_token.run_sync_soon(in_trio_thread)


def _send_message_to_system_task(
message: Run[RetT] | RunSync[RetT], trio_token: TrioToken
) -> None:
if type(message) is RunSync:
run_sync = message.run_sync
elif type(message) is Run:

def run_sync() -> None:
try:
trio.lowlevel.spawn_system_task(
message.run_system, name=message.afn, context=message.context
)
except RuntimeError: # system nursery is closed
message.queue.put_nowait(
outcome.Error(trio.RunFinishedError("system nursery is closed"))
)

else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
"trio.to_thread.run_sync received unrecognized thread message {!r}."
"".format(message)
)
trio_token.run_sync_soon(run_sync)


def from_thread_run(
afn: Callable[..., Awaitable[RetT]],
*args: object,
Expand All @@ -467,17 +456,15 @@ def from_thread_run(
RunFinishedError: if the corresponding call to :func:`trio.run` has
already completed, or if the run has started its final cleanup phase
and can no longer spawn new system tasks.
Cancelled: if the corresponding `trio.to_thread.run_sync` task is
cancellable and exits before this function is called, or
if the task enters cancelled status or call to :func:`trio.run` completes
while ``afn(*args)`` is running, then ``afn`` is likely to raise
Cancelled: if the task enters cancelled status or call to :func:`trio.run`
completes while ``afn(*args)`` is running, then ``afn`` is likely to raise
:exc:`trio.Cancelled`.
RuntimeError: if you try calling this from inside the Trio thread,
which would otherwise cause a deadlock, or if no ``trio_token`` was
provided, and we can't infer one from context.
TypeError: if ``afn`` is not an asynchronous function.
**Locating a Trio Token**: There are two ways to specify which
**Locating a TrioToken**: There are two ways to specify which
`trio.run` loop to reenter:
- Spawn this thread from `trio.to_thread.run_sync`. Trio will
Expand All @@ -486,17 +473,20 @@ def from_thread_run(
- Pass a keyword argument, ``trio_token`` specifying a specific
`trio.run` loop to re-enter. This is useful in case you have a
"foreign" thread, spawned using some other framework, and still want
to enter Trio, or if you want to avoid the cancellation context of
`trio.to_thread.run_sync`.
to enter Trio, or if you want to use a new system task to call ``afn``,
maybe to avoid the cancellation context of a corresponding
`trio.to_thread.run_sync` task.
"""
if trio_token is None:
send_message = _send_message_to_host_task
else:
send_message = _send_message_to_system_task
token_provided = trio_token is not None
trio_token = _check_token(trio_token)

message_to_trio = Run(afn, args, contextvars.copy_context())

send_message(message_to_trio, _check_token(trio_token))
if token_provided or PARENT_TASK_DATA.abandon_on_cancel:
message_to_trio.run_in_system_nursery(trio_token)
else:
message_to_trio.run_in_host_task(trio_token)

return message_to_trio.queue.get().unwrap() # type: ignore[no-any-return]


Expand All @@ -522,7 +512,7 @@ def from_thread_run_sync(
provided, and we can't infer one from context.
TypeError: if ``fn`` is an async function.
**Locating a Trio Token**: There are two ways to specify which
**Locating a TrioToken**: There are two ways to specify which
`trio.run` loop to reenter:
- Spawn this thread from `trio.to_thread.run_sync`. Trio will
Expand All @@ -531,15 +521,18 @@ def from_thread_run_sync(
- Pass a keyword argument, ``trio_token`` specifying a specific
`trio.run` loop to re-enter. This is useful in case you have a
"foreign" thread, spawned using some other framework, and still want
to enter Trio, or if you want to avoid the cancellation context of
`trio.to_thread.run_sync`.
to enter Trio, or if you want to use a new system task to call ``fn``,
maybe to avoid the cancellation context of a corresponding
`trio.to_thread.run_sync` task.
"""
if trio_token is None:
send_message = _send_message_to_host_task
else:
send_message = _send_message_to_system_task
token_provided = trio_token is not None
trio_token = _check_token(trio_token)

message_to_trio = RunSync(fn, args, contextvars.copy_context())

send_message(message_to_trio, _check_token(trio_token))
if token_provided or PARENT_TASK_DATA.abandon_on_cancel:
message_to_trio.run_in_system_nursery(trio_token)
else:
message_to_trio.run_in_host_task(trio_token)

return message_to_trio.queue.get().unwrap() # type: ignore[no-any-return]

0 comments on commit 5d93ed9

Please sign in to comment.