Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overload variables class for better typing experience #1919

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
pip list
- name: Ruff
run: |
ruff .
ruff check .
ruff format --check .
typos .
- name: Tests
Expand Down
17 changes: 12 additions & 5 deletions src/prompt_toolkit/contrib/regular_languages/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from __future__ import annotations

import re
from typing import Callable, Dict, Iterable, Iterator, Pattern
from typing import Callable, Dict, Iterable, Iterator, Pattern, TypeVar, overload
from typing import Match as RegexMatch

from .regex_parser import (
Expand All @@ -57,9 +57,7 @@
tokenize_regex,
)

__all__ = [
"compile",
]
__all__ = ["compile", "Match", "Variables"]


# Name of the named group in the regex, matching trailing input.
Expand Down Expand Up @@ -491,6 +489,9 @@ def end_nodes(self) -> Iterable[MatchVariable]:
yield MatchVariable(varname, value, (reg[0], reg[1]))


_T = TypeVar("_T")


class Variables:
def __init__(self, tuples: list[tuple[str, str, tuple[int, int]]]) -> None:
#: List of (varname, value, slice) tuples.
Expand All @@ -502,7 +503,13 @@ def __repr__(self) -> str:
", ".join(f"{k}={v!r}" for k, v, _ in self._tuples),
)

def get(self, key: str, default: str | None = None) -> str | None:
@overload
def get(self, key: str) -> str | None: ...

@overload
def get(self, key: str, default: str | _T) -> str | _T: ...

def get(self, key: str, default: str | _T | None = None) -> str | _T | None:
items = self.getall(key)
return items[0] if items else default

Expand Down
22 changes: 13 additions & 9 deletions src/prompt_toolkit/output/defaults.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import TextIO, cast
from typing import TYPE_CHECKING, TextIO, cast

from prompt_toolkit.utils import (
get_bell_environment_variable,
Expand All @@ -13,13 +13,17 @@
from .color_depth import ColorDepth
from .plain_text import PlainTextOutput

if TYPE_CHECKING:
from prompt_toolkit.patch_stdout import StdoutProxy


__all__ = [
"create_output",
]


def create_output(
stdout: TextIO | None = None, always_prefer_tty: bool = False
stdout: TextIO | StdoutProxy | None = None, always_prefer_tty: bool = False
) -> Output:
"""
Return an :class:`~prompt_toolkit.output.Output` instance for the command
Expand Down Expand Up @@ -54,13 +58,6 @@ def create_output(
stdout = io
break

# If the output is still `None`, use a DummyOutput.
# This happens for instance on Windows, when running the application under
# `pythonw.exe`. In that case, there won't be a terminal Window, and
# stdin/stdout/stderr are `None`.
if stdout is None:
return DummyOutput()

# If the patch_stdout context manager has been used, then sys.stdout is
# replaced by this proxy. For prompt_toolkit applications, we want to use
# the real stdout.
Expand All @@ -69,6 +66,13 @@ def create_output(
while isinstance(stdout, StdoutProxy):
stdout = stdout.original_stdout

# If the output is still `None`, use a DummyOutput.
# This happens for instance on Windows, when running the application under
# `pythonw.exe`. In that case, there won't be a terminal Window, and
# stdin/stdout/stderr are `None`.
if stdout is None:
return DummyOutput()

if sys.platform == "win32":
from .conemu import ConEmuOutput
from .win32 import Win32Output
Expand Down
2 changes: 1 addition & 1 deletion src/prompt_toolkit/patch_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def flush(self) -> None:
self._flush()

@property
def original_stdout(self) -> TextIO:
def original_stdout(self) -> TextIO | None:
return self._output.stdout or sys.__stdout__

# Attributes for compatibility with sys.__stdout__:
Expand Down
Loading