Skip to content

Commit

Permalink
Get rid of a bunch more Anys and resolve mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
CoolCat467 committed Oct 28, 2024
1 parent 521c1b7 commit 55964ad
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 89 deletions.
8 changes: 5 additions & 3 deletions src/trio/_core/_concat_tb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from types import TracebackType
from typing import ClassVar, cast
from typing import TYPE_CHECKING, ClassVar, cast

################################################################
# concat_tb
Expand Down Expand Up @@ -88,7 +88,7 @@ def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackT
# cpython/pypy in current type checkers.
def controller( # type: ignore[no-any-unimported]

Check warning on line 89 in src/trio/_core/_concat_tb.py

View check run for this annotation

Codecov / codecov/patch

src/trio/_core/_concat_tb.py#L89

Added line #L89 was not covered by tests
operation: tputil.ProxyOperation,
) -> object | None:
) -> TracebackType | None:
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
Expand All @@ -101,8 +101,10 @@ def controller( # type: ignore[no-any-unimported]
"__getattr__",
}
and operation.args[0] == "tb_next"
): # pragma: no cover
) or TYPE_CHECKING: # pragma: no cover
return tb_next
if TYPE_CHECKING:
raise RuntimeError("Should not be possible")
return operation.delegate() # Delegate is reverting to original behaviour

return cast(
Expand Down
8 changes: 7 additions & 1 deletion src/trio/_core/_ki.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,13 @@ class _IdRef(weakref.ref[_T]):
__slots__ = ("_hash",)
_hash: int

def __new__(cls, ob: _T, callback: Callable[[Self], Any] | None = None, /) -> Self:
# Explicit "Any" is not allowed
def __new__( # type: ignore[misc]
cls,
ob: _T,
callback: Callable[[Self], Any] | None = None,
/,
) -> Self:
self: Self = weakref.ref.__new__(cls, ob, callback)
self._hash = object.__hash__(ob)
return self
Expand Down
6 changes: 2 additions & 4 deletions src/trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,16 @@
# for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in
# start_guest_run. Same with types.FrameType in iter_await_frames
import outcome
from typing_extensions import ParamSpec, Self, TypeVar, TypeVarTuple, Unpack
from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack

PosArgT = TypeVarTuple("PosArgT")
StatusT = TypeVar("StatusT", default=None)
StatusT_contra = TypeVar("StatusT_contra", contravariant=True, default=None)
PS = ParamSpec("PS")
else:
from typing import TypeVar

StatusT = TypeVar("StatusT")
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
PS = TypeVar("PS")

RetT = TypeVar("RetT")

Expand All @@ -103,7 +101,7 @@ class _NoStatus(metaclass=NoPublicConstructor):

# Decorator to mark methods public. This does nothing by itself, but
# trio/_tools/gen_exports.py looks for it.
def _public(fn: Callable[PS, RetT]) -> Callable[PS, RetT]:
def _public(fn: RetT) -> RetT:
return fn


Expand Down
61 changes: 42 additions & 19 deletions src/trio/_core/_tests/test_guest_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
import time
import traceback
import warnings
from collections.abc import AsyncGenerator, Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
from functools import partial
from math import inf
from typing import (
TYPE_CHECKING,
Any,
NoReturn,
TypeVar,
)
Expand All @@ -26,7 +25,7 @@

import trio
import trio.testing
from trio.abc import Instrument
from trio.abc import Clock, Instrument

from ..._util import signal_raise
from .tutil import gc_collect_harder, restore_unraisablehook
Expand All @@ -37,7 +36,7 @@
from trio._channel import MemorySendChannel

T = TypeVar("T")
InHost: TypeAlias = Callable[[object], None]
InHost: TypeAlias = Callable[[Callable[[], object]], None]


# The simplest possible "host" loop.
Expand All @@ -46,12 +45,15 @@
# our main
# - final result is returned
# - any unhandled exceptions cause an immediate crash
# Explicit "Any" is not allowed
def trivial_guest_run( # type: ignore[misc]
trio_fn: Callable[..., Awaitable[T]],
def trivial_guest_run(
trio_fn: Callable[[InHost], Awaitable[T]],
*,
in_host_after_start: Callable[[], None] | None = None,
**start_guest_run_kwargs: Any,
host_uses_signal_set_wakeup_fd: bool = False,
clock: Clock | None = None,
instruments: Sequence[Instrument] = (),
restrict_keyboard_interrupt_to_checkpoints: bool = False,
strict_exception_groups: bool = True,
) -> T:
todo: queue.Queue[tuple[str, Outcome[T] | Callable[[], object]]] = queue.Queue()

Expand Down Expand Up @@ -87,7 +89,11 @@ def done_callback(outcome: Outcome[T]) -> None:
run_sync_soon_threadsafe=run_sync_soon_threadsafe,
run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe,
done_callback=done_callback,
**start_guest_run_kwargs,
host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd,
clock=clock,
instruments=instruments,
restrict_keyboard_interrupt_to_checkpoints=restrict_keyboard_interrupt_to_checkpoints,
strict_exception_groups=strict_exception_groups,
)
if in_host_after_start is not None:
in_host_after_start()
Expand Down Expand Up @@ -171,10 +177,16 @@ async def early_task() -> None:
assert res == "ok"
assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"}

class BadClock:
class BadClock(Clock):
def start_clock(self) -> NoReturn:
raise ValueError("whoops")

def current_time(self) -> float:
raise NotImplementedError()

def deadline_to_sleep_time(self, deadline: float) -> float:
raise NotImplementedError()

def after_start_never_runs() -> None: # pragma: no cover
pytest.fail("shouldn't get here")

Expand Down Expand Up @@ -431,12 +443,16 @@ async def abandoned_main(in_host: InHost) -> None:
trio.current_time()


# Explicit "Any" is not allowed
def aiotrio_run( # type: ignore[misc]
trio_fn: Callable[..., Awaitable[T]],
def aiotrio_run(
trio_fn: Callable[[], Awaitable[T]],
*,
pass_not_threadsafe: bool = True,
**start_guest_run_kwargs: Any,
run_sync_soon_not_threadsafe: InHost | None = None,
host_uses_signal_set_wakeup_fd: bool = False,
clock: Clock | None = None,
instruments: Sequence[Instrument] = (),
restrict_keyboard_interrupt_to_checkpoints: bool = False,
strict_exception_groups: bool = True,
) -> T:
loop = asyncio.new_event_loop()

Expand All @@ -448,13 +464,18 @@ def trio_done_callback(main_outcome: Outcome[object]) -> None:
trio_done_fut.set_result(main_outcome)

if pass_not_threadsafe:
start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon
run_sync_soon_not_threadsafe = loop.call_soon

trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
done_callback=trio_done_callback,
**start_guest_run_kwargs,
run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe,
host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd,
clock=clock,
instruments=instruments,
restrict_keyboard_interrupt_to_checkpoints=restrict_keyboard_interrupt_to_checkpoints,
strict_exception_groups=strict_exception_groups,
)

return (await trio_done_fut).unwrap() # type: ignore[no-any-return]
Expand Down Expand Up @@ -557,12 +578,14 @@ async def crash_in_worker_thread_io(in_host: InHost) -> None:
t = threading.current_thread()
old_get_events = trio._core._run.TheIOManager.get_events

# Explicit "Any" is not allowed
def bad_get_events(*args: Any) -> object: # type: ignore[misc]
def bad_get_events(
self: trio._core._run.TheIOManager,
timeout: float,
) -> trio._core._run.EventResult:
if threading.current_thread() is not t:
raise ValueError("oh no!")
else:
return old_get_events(*args)
return old_get_events(self, timeout)

m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events)

Expand Down
5 changes: 4 additions & 1 deletion src/trio/_core/_tests/test_ki.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,10 @@ async def _consume_async_generator(agen: AsyncGenerator[None, None]) -> None:
await agen.aclose()


def _consume_function_for_coverage(fn: Callable[..., object]) -> None:
# Explicit "Any" is not allowed
def _consume_function_for_coverage( # type: ignore[misc]
fn: Callable[..., object],
) -> None:
result = fn()
if inspect.isasyncgen(result):
result = _consume_async_generator(result)
Expand Down
12 changes: 9 additions & 3 deletions src/trio/_core/_tests/test_parking_lot.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,10 @@ async def test_parking_lot_breaker_registration() -> None:

# registering a task as breaker on an already broken lot is fine
lot.break_lot()
child_task = None
child_task: _core.Task | None = None
async with trio.open_nursery() as nursery:
child_task = await nursery.start(dummy_task)
assert isinstance(child_task, _core.Task)
add_parking_lot_breaker(child_task, lot)
nursery.cancel_scope.cancel()
assert lot.broken_by == [task, child_task]
Expand Down Expand Up @@ -339,6 +340,9 @@ async def test_parking_lot_multiple_breakers_exit() -> None:
child_task1 = await nursery.start(dummy_task)
child_task2 = await nursery.start(dummy_task)
child_task3 = await nursery.start(dummy_task)
assert isinstance(child_task1, _core.Task)
assert isinstance(child_task2, _core.Task)
assert isinstance(child_task3, _core.Task)
add_parking_lot_breaker(child_task1, lot)
add_parking_lot_breaker(child_task2, lot)
add_parking_lot_breaker(child_task3, lot)
Expand All @@ -350,9 +354,11 @@ async def test_parking_lot_multiple_breakers_exit() -> None:

async def test_parking_lot_breaker_register_exited_task() -> None:
lot = ParkingLot()
child_task = None
child_task: _core.Task | None = None
async with trio.open_nursery() as nursery:
child_task = await nursery.start(dummy_task)
value = await nursery.start(dummy_task)
assert isinstance(value, _core.Task)
child_task = value
nursery.cancel_scope.cancel()
# trying to register an exited task as lot breaker errors
with pytest.raises(
Expand Down
12 changes: 9 additions & 3 deletions src/trio/_core/_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,9 @@ async def task3(task_status: _core.TaskStatus[_core.CancelScope]) -> None:
await sleep_forever()

async with _core.open_nursery() as nursery:
scope: _core.CancelScope = await nursery.start(task3)
value = await nursery.start(task3)
assert isinstance(value, _core.CancelScope)
scope: _core.CancelScope = value
with pytest.raises(RuntimeError, match="from unrelated"):
scope.__exit__(None, None, None)
scope.cancel()
Expand Down Expand Up @@ -1963,15 +1965,19 @@ async def sleeping_children(

# Cancelling the setup_nursery just *before* calling started()
async with _core.open_nursery() as nursery:
target_nursery: _core.Nursery = await nursery.start(setup_nursery)
value = await nursery.start(setup_nursery)
assert isinstance(value, _core.Nursery)
target_nursery: _core.Nursery = value
await target_nursery.start(
sleeping_children,
target_nursery.cancel_scope.cancel,
)

# Cancelling the setup_nursery just *after* calling started()
async with _core.open_nursery() as nursery:
target_nursery = await nursery.start(setup_nursery)
value = await nursery.start(setup_nursery)
assert isinstance(value, _core.Nursery)
target_nursery = value
await target_nursery.start(sleeping_children, lambda: None)
target_nursery.cancel_scope.cancel()

Expand Down
11 changes: 9 additions & 2 deletions src/trio/_core/_thread_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
from functools import partial
from itertools import count
from threading import Lock, Thread
from typing import Any, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import outcome

if TYPE_CHECKING:
from collections.abc import Callable

RetT = TypeVar("RetT")


Expand Down Expand Up @@ -126,6 +129,8 @@ def darwin_namefunc(


class WorkerThread(Generic[RetT]):
__slots__ = ("_default_name", "_job", "_thread", "_thread_cache", "_worker_lock")

def __init__(self, thread_cache: ThreadCache) -> None:
self._job: (
tuple[
Expand Down Expand Up @@ -207,8 +212,10 @@ def _work(self) -> None:


class ThreadCache:
__slots__ = ("_idle_workers",)

def __init__(self) -> None:
# Explicit "Any" is not allowed
# Explicit "Any" not allowed
self._idle_workers: dict[WorkerThread[Any], None] = {} # type: ignore[misc]

def start_thread_soon(
Expand Down
Loading

0 comments on commit 55964ad

Please sign in to comment.