Skip to content

Commit

Permalink
Merge pull request #1456 from nitzan-shaked/typing
Browse files Browse the repository at this point in the history
basic typing fixes
  • Loading branch information
willmcgugan authored Jan 6, 2023
2 parents 8ea2eae + 6e9d302 commit 70bded0
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 84 deletions.
11 changes: 6 additions & 5 deletions src/textual/_node_list.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import sys

from typing import TYPE_CHECKING, Iterator, Sequence, overload
from typing import TYPE_CHECKING, Any, Iterator, Sequence, overload

import rich.repr

Expand All @@ -13,7 +14,7 @@ class DuplicateIds(Exception):


@rich.repr.auto(angular=True)
class NodeList(Sequence):
class NodeList(Sequence["Widget"]):
"""
A container for widgets that forms one level of hierarchy.
Expand Down Expand Up @@ -46,10 +47,10 @@ def __rich_repr__(self) -> rich.repr.Result:
def __len__(self) -> int:
return len(self._nodes)

def __contains__(self, widget: Widget) -> bool:
def __contains__(self, widget: object) -> bool:
return widget in self._nodes

def index(self, widget: Widget) -> int:
def index(self, widget: Any, start: int = 0, stop: int = sys.maxsize) -> int:
"""Return the index of the given widget.
Args:
Expand All @@ -61,7 +62,7 @@ def index(self, widget: Widget) -> int:
Raises:
ValueError: If the widget is not in the node list.
"""
return self._nodes.index(widget)
return self._nodes.index(widget, start, stop)

def _get_by_id(self, widget_id: str) -> Widget | None:
"""Get the widget for the given widget_id, or None if there's no matches in this list"""
Expand Down
5 changes: 2 additions & 3 deletions src/textual/css/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from rich.traceback import Traceback

from ._help_renderables import HelpText
from .tokenize import Token
from .tokenizer import TokenError
from .tokenizer import Token, TokenError


class DeclarationError(Exception):
Expand All @@ -32,7 +31,7 @@ class StyleValueError(ValueError):
error is raised.
"""

def __init__(self, *args, help_text: HelpText | None = None):
def __init__(self, *args: object, help_text: HelpText | None = None):
super().__init__(*args)
self.help_text = help_text

Expand Down
2 changes: 1 addition & 1 deletion src/textual/css/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def skip_to(self, expect: Expect) -> Token:
while True:
if line_no >= len(self.lines):
raise EOFError(
self.path, self.code, line_no, col_no, "Unexpected end of file"
self.path, self.code, (line_no, col_no), "Unexpected end of file"
)
line = self.lines[line_no]
match = expect.search(line, col_no)
Expand Down
8 changes: 2 additions & 6 deletions src/textual/devtools/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ async def _on_startup(app: Application) -> None:
def _run_devtools(verbose: bool, exclude: list[str] | None = None) -> None:
app = _make_devtools_aiohttp_app(verbose=verbose, exclude=exclude)

def noop_print(_: str):
return None
def noop_print(_: str) -> None:
pass

try:
run_app(
Expand Down Expand Up @@ -77,7 +77,3 @@ def _make_devtools_aiohttp_app(
)

return app


if __name__ == "__main__":
_run_devtools()
4 changes: 3 additions & 1 deletion src/textual/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_distance_to(self, other: Offset) -> float:
"""
x1, y1 = self
x2, y2 = other
distance = ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) ** 0.5
distance: float = ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) ** 0.5
return distance


Expand Down Expand Up @@ -217,6 +217,8 @@ def contains_point(self, point: tuple[int, int]) -> bool:

def __contains__(self, other: Any) -> bool:
try:
x: int
y: int
x, y = other
except Exception:
raise TypeError(
Expand Down
13 changes: 6 additions & 7 deletions src/textual/pilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,23 @@
import rich.repr

import asyncio
from typing import TYPE_CHECKING
from typing import Generic

if TYPE_CHECKING:
from .app import App
from .app import App, ReturnType


@rich.repr.auto(angular=True)
class Pilot:
class Pilot(Generic[ReturnType]):
"""Pilot object to drive an app."""

def __init__(self, app: App) -> None:
def __init__(self, app: App[ReturnType]) -> None:
self._app = app

def __rich_repr__(self) -> rich.repr.Result:
yield "app", self._app

@property
def app(self) -> App:
def app(self) -> App[ReturnType]:
"""App: A reference to the application."""
return self._app

Expand All @@ -47,7 +46,7 @@ async def wait_for_animation(self) -> None:
"""Wait for any animation to complete."""
await self._app.animator.wait_for_idle()

async def exit(self, result: object) -> None:
async def exit(self, result: ReturnType) -> None:
"""Exit the app with the given result.
Args:
Expand Down
30 changes: 20 additions & 10 deletions src/textual/renderables/sparkline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import statistics
from typing import Sequence, Iterable, Callable, TypeVar
from typing import Generic, Sequence, Iterable, Callable, TypeVar

from rich.color import Color
from rich.console import ConsoleOptions, Console, RenderResult
Expand All @@ -12,8 +12,10 @@

T = TypeVar("T", int, float)

SummaryFunction = Callable[[Sequence[T]], float]

class Sparkline:

class Sparkline(Generic[T]):
"""A sparkline representing a series of data.
Args:
Expand All @@ -33,16 +35,16 @@ def __init__(
width: int | None,
min_color: Color = Color.from_rgb(0, 255, 0),
max_color: Color = Color.from_rgb(255, 0, 0),
summary_function: Callable[[list[T]], float] = max,
summary_function: SummaryFunction[T] = max,
) -> None:
self.data = data
self.data: Sequence[T] = data
self.width = width
self.min_color = Style.from_color(min_color)
self.max_color = Style.from_color(max_color)
self.summary_function = summary_function
self.summary_function: SummaryFunction[T] = summary_function

@classmethod
def _buckets(cls, data: Sequence[T], num_buckets: int) -> Iterable[list[T]]:
def _buckets(cls, data: Sequence[T], num_buckets: int) -> Iterable[Sequence[T]]:
"""Partition ``data`` into ``num_buckets`` buckets. For example, the data
[1, 2, 3, 4] partitioned into 2 buckets is [[1, 2], [3, 4]].
Expand Down Expand Up @@ -73,13 +75,15 @@ def __rich_console__(
minimum, maximum = min(self.data), max(self.data)
extent = maximum - minimum or 1

buckets = list(self._buckets(self.data, num_buckets=self.width))
buckets = tuple(self._buckets(self.data, num_buckets=width))

bucket_index = 0
bucket_index = 0.0
bars_rendered = 0
step = len(buckets) / width
summary_function = self.summary_function
min_color, max_color = self.min_color.color, self.max_color.color
assert min_color is not None
assert max_color is not None
while bars_rendered < width:
partition = buckets[int(bucket_index)]
partition_summary = summary_function(partition)
Expand All @@ -94,10 +98,16 @@ def __rich_console__(
if __name__ == "__main__":
console = Console()

def last(l):
def last(l: Sequence[T]) -> T:
return l[-1]

funcs = min, max, last, statistics.median, statistics.mean
funcs: Sequence[SummaryFunction[int]] = (
min,
max,
last,
statistics.median,
statistics.mean,
)
nums = [10, 2, 30, 60, 45, 20, 7, 8, 9, 10, 50, 13, 10, 6, 5, 4, 3, 7, 20]
console.print(f"data = {nums}\n")
for f in funcs:
Expand Down
48 changes: 8 additions & 40 deletions src/textual/renderables/text_opacity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Iterable
from typing import Iterable, Tuple, cast

from rich.cells import cell_len
from rich.color import Color
Expand Down Expand Up @@ -62,12 +62,17 @@ def process_segments(
_Segment = Segment
_from_color = Style.from_color
if opacity == 0:
for text, style, control in segments:
for text, style, control in cast(
# use Tuple rather than tuple so Python 3.7 doesn't complain
Iterable[Tuple[str, Style, object]],
segments,
):
invisible_style = _from_color(bgcolor=style.bgcolor)
yield _Segment(cell_len(text) * " ", invisible_style)
else:
for segment in segments:
text, style, control = segment
# use Tuple rather than tuple so Python 3.7 doesn't complain
text, style, control = cast(Tuple[str, Style, object], segment)
if not style:
yield segment
continue
Expand All @@ -85,40 +90,3 @@ def __rich_console__(
) -> RenderResult:
segments = console.render(self.renderable, options)
return self.process_segments(segments, self.opacity)


if __name__ == "__main__":
from rich.live import Live
from rich.panel import Panel
from rich.text import Text

from time import sleep

console = Console()

panel = Panel.fit(
Text("Steak: £30", style="#fcffde on #03761e"),
title="Menu",
style="#ffffff on #000000",
)
console.print(panel)

opacity_panel = TextOpacity(panel, opacity=0.5)
console.print(opacity_panel)

def frange(start, end, step):
current = start
while current < end:
yield current
current += step

while current >= 0:
yield current
current -= step

import itertools

with Live(opacity_panel, refresh_per_second=60) as live:
for value in itertools.cycle(frange(0, 1, 0.05)):
opacity_panel.value = value
sleep(0.05)
4 changes: 2 additions & 2 deletions src/textual/strip.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def __reversed__(self) -> Iterator[Segment]:
def __len__(self) -> int:
return len(self._segments)

def __eq__(self, strip: Strip) -> bool:
return (
def __eq__(self, strip: object) -> bool:
return isinstance(strip, Strip) and (
self._segments == strip._segments and self.cell_length == strip.cell_length
)

Expand Down
20 changes: 11 additions & 9 deletions src/textual/widgets/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@


@dataclass
class _TreeLine:
path: list[TreeNode]
class _TreeLine(Generic[TreeDataType]):
path: list[TreeNode[TreeDataType]]
last: bool

@property
def node(self) -> TreeNode:
def node(self) -> TreeNode[TreeDataType]:
"""TreeNode: The node associated with this line."""
return self.path[-1]

Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
self._label = tree.process_label(label)
self.data = data
self._expanded = expanded
self._children: list[TreeNode] = []
self._children: list[TreeNode[TreeDataType]] = []

self._hover_ = False
self._selected_ = False
Expand Down Expand Up @@ -475,11 +475,11 @@ def clear(self) -> None:
self._updates += 1
self.refresh()

def select_node(self, node: TreeNode | None) -> None:
def select_node(self, node: TreeNode[TreeDataType] | None) -> None:
"""Move the cursor to the given node, or reset cursor.
Args:
node (TreeNode | None): A tree node, or None to reset cursor.
node (TreeNode[TreeDataType] | None): A tree node, or None to reset cursor.
"""
self.cursor_line = -1 if node is None else node._line

Expand Down Expand Up @@ -579,11 +579,11 @@ def scroll_to_line(self, line: int) -> None:
"""
self.scroll_to_region(Region(0, line, self.size.width, 1))

def scroll_to_node(self, node: TreeNode) -> None:
def scroll_to_node(self, node: TreeNode[TreeDataType]) -> None:
"""Scroll to the given node.
Args:
node (TreeNode): Node to scroll in to view.
node (TreeNode[TreeDataType]): Node to scroll in to view.
"""
line = node._line
if line != -1:
Expand Down Expand Up @@ -637,7 +637,9 @@ def _build(self) -> None:

root = self.root

def add_node(path: list[TreeNode], node: TreeNode, last: bool) -> None:
def add_node(
path: list[TreeNode[TreeDataType]], node: TreeNode[TreeDataType], last: bool
) -> None:
child_path = [*path, node]
node._line = len(lines)
add_line(TreeLine(child_path, last))
Expand Down

0 comments on commit 70bded0

Please sign in to comment.