diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index eedbdd8ef..c153c2b18 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -31,7 +31,7 @@ jobs: pip list - name: Ruff run: | - ruff . + ruff check . ruff format --check . typos . - name: Tests diff --git a/src/prompt_toolkit/contrib/regular_languages/compiler.py b/src/prompt_toolkit/contrib/regular_languages/compiler.py index dd558a68a..4009d54f2 100644 --- a/src/prompt_toolkit/contrib/regular_languages/compiler.py +++ b/src/prompt_toolkit/contrib/regular_languages/compiler.py @@ -42,7 +42,7 @@ from __future__ import annotations import re -from typing import Callable, Dict, Iterable, Iterator, Pattern +from typing import Callable, Dict, Iterable, Iterator, Pattern, TypeVar, overload from typing import Match as RegexMatch from .regex_parser import ( @@ -57,9 +57,7 @@ tokenize_regex, ) -__all__ = [ - "compile", -] +__all__ = ["compile", "Match", "Variables"] # Name of the named group in the regex, matching trailing input. @@ -491,6 +489,9 @@ def end_nodes(self) -> Iterable[MatchVariable]: yield MatchVariable(varname, value, (reg[0], reg[1])) +_T = TypeVar("_T") + + class Variables: def __init__(self, tuples: list[tuple[str, str, tuple[int, int]]]) -> None: #: List of (varname, value, slice) tuples. @@ -502,7 +503,13 @@ def __repr__(self) -> str: ", ".join(f"{k}={v!r}" for k, v, _ in self._tuples), ) - def get(self, key: str, default: str | None = None) -> str | None: + @overload + def get(self, key: str) -> str | None: ... + + @overload + def get(self, key: str, default: str | _T) -> str | _T: ... + + def get(self, key: str, default: str | _T | None = None) -> str | _T | None: items = self.getall(key) return items[0] if items else default diff --git a/src/prompt_toolkit/output/defaults.py b/src/prompt_toolkit/output/defaults.py index ed114e32a..6b06ed43c 100644 --- a/src/prompt_toolkit/output/defaults.py +++ b/src/prompt_toolkit/output/defaults.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import TextIO, cast +from typing import TYPE_CHECKING, TextIO, cast from prompt_toolkit.utils import ( get_bell_environment_variable, @@ -13,13 +13,17 @@ from .color_depth import ColorDepth from .plain_text import PlainTextOutput +if TYPE_CHECKING: + from prompt_toolkit.patch_stdout import StdoutProxy + + __all__ = [ "create_output", ] def create_output( - stdout: TextIO | None = None, always_prefer_tty: bool = False + stdout: TextIO | StdoutProxy | None = None, always_prefer_tty: bool = False ) -> Output: """ Return an :class:`~prompt_toolkit.output.Output` instance for the command @@ -54,13 +58,6 @@ def create_output( stdout = io break - # If the output is still `None`, use a DummyOutput. - # This happens for instance on Windows, when running the application under - # `pythonw.exe`. In that case, there won't be a terminal Window, and - # stdin/stdout/stderr are `None`. - if stdout is None: - return DummyOutput() - # If the patch_stdout context manager has been used, then sys.stdout is # replaced by this proxy. For prompt_toolkit applications, we want to use # the real stdout. @@ -69,6 +66,13 @@ def create_output( while isinstance(stdout, StdoutProxy): stdout = stdout.original_stdout + # If the output is still `None`, use a DummyOutput. + # This happens for instance on Windows, when running the application under + # `pythonw.exe`. In that case, there won't be a terminal Window, and + # stdin/stdout/stderr are `None`. + if stdout is None: + return DummyOutput() + if sys.platform == "win32": from .conemu import ConEmuOutput from .win32 import Win32Output diff --git a/src/prompt_toolkit/patch_stdout.py b/src/prompt_toolkit/patch_stdout.py index 4958e9d2e..e1f2a7a2c 100644 --- a/src/prompt_toolkit/patch_stdout.py +++ b/src/prompt_toolkit/patch_stdout.py @@ -273,7 +273,7 @@ def flush(self) -> None: self._flush() @property - def original_stdout(self) -> TextIO: + def original_stdout(self) -> TextIO | None: return self._output.stdout or sys.__stdout__ # Attributes for compatibility with sys.__stdout__: