Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster query_one #4950

Merged
merged 12 commits into from
Aug 28, 2024
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added

- Added `DOMNode.check_consume_key` https://github.com/Textualize/textual/pull/4940
- Added `DOMNode.query_exactly_one` https://github.com/Textualize/textual/pull/4950
- Added `SelectorSet.is_simple` https://github.com/Textualize/textual/pull/4950

### Changed

- KeyPanel will show multiple keys if bound to the same action https://github.com/Textualize/textual/pull/4940
- Breaking change: `DOMNode.query_one` will not `raise TooManyMatches` https://github.com/Textualize/textual/pull/4950

## [0.78.0] - 2024-08-27

Expand Down
1 change: 0 additions & 1 deletion docs/guide/queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ send_button = self.query_one("#send")

This will retrieve a widget with an ID of `send`, if there is exactly one.
If there are no matching widgets, Textual will raise a [NoMatches][textual.css.query.NoMatches] exception.
If there is more than one match, Textual will raise a [TooManyMatches][textual.css.query.TooManyMatches] exception.

You can also add a second parameter for the expected type, which will ensure that you get the type you are expecting.

Expand Down
21 changes: 15 additions & 6 deletions src/textual/_node_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
if TYPE_CHECKING:
from _typeshed import SupportsRichComparison

from .dom import DOMNode
from .widget import Widget


Expand All @@ -24,7 +25,8 @@ class NodeList(Sequence["Widget"]):
Although named a list, widgets may appear only once, making them more like a set.
"""

def __init__(self) -> None:
def __init__(self, parent: DOMNode | None = None) -> None:
self._parent = parent
# The nodes in the list
self._nodes: list[Widget] = []
self._nodes_set: set[Widget] = set()
Expand Down Expand Up @@ -52,6 +54,13 @@ def __len__(self) -> int:
def __contains__(self, widget: object) -> bool:
return widget in self._nodes

def updated(self) -> None:
"""Mark the nodes as having been updated."""
self._updates += 1
node = self._parent
while node is not None and (node := node._parent) is not None:
node._nodes._updates += 1

def _sort(
self,
*,
Expand All @@ -69,7 +78,7 @@ def _sort(
else:
self._nodes.sort(key=key, reverse=reverse)

self._updates += 1
self.updated()

def index(self, widget: Any, start: int = 0, stop: int = sys.maxsize) -> int:
"""Return the index of the given widget.
Expand Down Expand Up @@ -102,7 +111,7 @@ def _append(self, widget: Widget) -> None:
if widget_id is not None:
self._ensure_unique_id(widget_id)
self._nodes_by_id[widget_id] = widget
self._updates += 1
self.updated()

def _insert(self, index: int, widget: Widget) -> None:
"""Insert a Widget.
Expand All @@ -117,7 +126,7 @@ def _insert(self, index: int, widget: Widget) -> None:
if widget_id is not None:
self._ensure_unique_id(widget_id)
self._nodes_by_id[widget_id] = widget
self._updates += 1
self.updated()

def _ensure_unique_id(self, widget_id: str) -> None:
if widget_id in self._nodes_by_id:
Expand All @@ -141,15 +150,15 @@ def _remove(self, widget: Widget) -> None:
widget_id = widget.id
if widget_id in self._nodes_by_id:
del self._nodes_by_id[widget_id]
self._updates += 1
self.updated()

def _clear(self) -> None:
"""Clear the node list."""
if self._nodes:
self._nodes.clear()
self._nodes_set.clear()
self._nodes_by_id.clear()
self._updates += 1
self.updated()

def __iter__(self) -> Iterator[Widget]:
return iter(self._nodes)
Expand Down
9 changes: 9 additions & 0 deletions src/textual/css/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ def __post_init__(self) -> None:
def css(self) -> str:
return RuleSet._selector_to_css(self.selectors)

@property
def is_simple(self) -> bool:
"""Are all the selectors simple (i.e. only dependent on static DOM state)."""
simple_types = {SelectorType.ID, SelectorType.TYPE}
return all(
(selector.type in simple_types and not selector.pseudo_classes)
for selector in self.selectors
)

def __rich_repr__(self) -> rich.repr.Result:
selectors = RuleSet._selector_to_css(self.selectors)
yield selectors
Expand Down
118 changes: 109 additions & 9 deletions src/textual/dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from rich.text import Text
from rich.tree import Tree

from textual.cache import LRUCache

from ._context import NoActiveAppError, active_message_pump
from ._node_list import NodeList
from ._types import WatchCallbackType
Expand All @@ -37,7 +39,9 @@
from .css._error_tools import friendly_list
from .css.constants import VALID_DISPLAY, VALID_VISIBILITY
from .css.errors import DeclarationError, StyleValueError
from .css.parse import parse_declarations
from .css.match import match
from .css.parse import parse_declarations, parse_selectors
from .css.query import NoMatches, TooManyMatches
from .css.styles import RenderStyles, Styles
from .css.tokenize import IDENTIFIER
from .message_pump import MessagePump
Expand All @@ -60,7 +64,7 @@
from .worker import Worker, WorkType, ResultType

# Unused & ignored imports are needed for the docs to link to these objects:
from .css.query import NoMatches, TooManyMatches, WrongType # type: ignore # noqa: F401
from .css.query import WrongType # type: ignore # noqa: F401

from typing_extensions import Literal

Expand All @@ -74,6 +78,10 @@
ReactiveType = TypeVar("ReactiveType")


QueryOneCacheKey: TypeAlias = "tuple[int, str, Type[Widget] | None]"
"""The key used to cache query_one results."""


class BadIdentifier(Exception):
"""Exception raised if you supply a `id` attribute or class name in the wrong format."""

Expand Down Expand Up @@ -184,13 +192,14 @@ def __init__(
self._name = name
self._id = None
if id is not None:
self.id = id
check_identifiers("id", id)
self._id = id

_classes = classes.split() if classes else []
check_identifiers("class name", *_classes)
self._classes.update(_classes)

self._nodes: NodeList = NodeList()
self._nodes: NodeList = NodeList(self)
self._css_styles: Styles = Styles(self)
self._inline_styles: Styles = Styles(self)
self.styles: RenderStyles = RenderStyles(
Expand All @@ -213,6 +222,8 @@ def __init__(
dict[str, tuple[MessagePump, Reactive | object]] | None
) = None
self._pruning = False
self._query_one_cache: LRUCache[QueryOneCacheKey, DOMNode] = LRUCache(1024)

super().__init__()

def set_reactive(
Expand Down Expand Up @@ -741,7 +752,7 @@ def id(self, new_id: str) -> str:
ValueError: If the ID has already been set.
"""
check_identifiers("id", new_id)

self._nodes.updated()
if self._id is not None:
raise ValueError(
f"Node 'id' attribute may not be changed once set (current id={self._id!r})"
Expand Down Expand Up @@ -1393,21 +1404,110 @@ def query_one(
Raises:
WrongType: If the wrong type was found.
NoMatches: If no node matches the query.
TooManyMatches: If there is more than one matching node in the query.

Returns:
A widget matching the selector.
"""
_rich_traceback_omit = True
from .css.query import DOMQuery

if isinstance(selector, str):
query_selector = selector
else:
query_selector = selector.__name__
query: DOMQuery[Widget] = DOMQuery(self, filter=query_selector)

return query.only_one() if expect_type is None else query.only_one(expect_type)
selector_set = parse_selectors(query_selector)

if all(selectors.is_simple for selectors in selector_set):
cache_key = (self._nodes._updates, query_selector, expect_type)
cached_result = self._query_one_cache.get(cache_key)
if cached_result is not None:
return cached_result
else:
cache_key = None

for node in walk_depth_first(self, with_root=False):
if not match(selector_set, node):
continue
if expect_type is not None and not isinstance(node, expect_type):
continue
if cache_key is not None:
self._query_one_cache[cache_key] = node
return node

raise NoMatches(f"No nodes match {selector!r} on {self!r}")

if TYPE_CHECKING:

@overload
def query_exactly_one(self, selector: str) -> Widget: ...

@overload
def query_exactly_one(self, selector: type[QueryType]) -> QueryType: ...

@overload
def query_exactly_one(
self, selector: str, expect_type: type[QueryType]
) -> QueryType: ...

def query_exactly_one(
self,
selector: str | type[QueryType],
expect_type: type[QueryType] | None = None,
) -> QueryType | Widget:
"""Get a widget from this widget's children that matches a selector or widget type.

!!! Note
This method is similar to [query_one][textual.dom.DOMNode.query_one].
The only difference is that it will raise `TooManyMatches` if there is more than a single match.

Args:
selector: A selector or widget type.
expect_type: Require the object be of the supplied type, or None for any type.

Raises:
WrongType: If the wrong type was found.
NoMatches: If no node matches the query.
TooManyMatches: If there is more than one matching node in the query (and `exactly_one==True`).

Returns:
A widget matching the selector.
"""
_rich_traceback_omit = True

if isinstance(selector, str):
query_selector = selector
else:
query_selector = selector.__name__

selector_set = parse_selectors(query_selector)

if all(selectors.is_simple for selectors in selector_set):
cache_key = (self._nodes._updates, query_selector, expect_type)
cached_result = self._query_one_cache.get(cache_key)
if cached_result is not None:
return cached_result
else:
cache_key = None

children = walk_depth_first(self, with_root=False)
iter_children = iter(children)
for node in iter_children:
if not match(selector_set, node):
continue
if expect_type is not None and not isinstance(node, expect_type):
continue
for later_node in iter_children:
if match(selector_set, later_node):
if expect_type is not None and not isinstance(node, expect_type):
continue
raise TooManyMatches(
"Call to query_one resulted in more than one matched node"
)
if cache_key is not None:
self._query_one_cache[cache_key] = node
return node

raise NoMatches(f"No nodes match {selector!r} on {self!r}")

def set_styles(self, css: str | None = None, **update_styles: Any) -> Self:
"""Set custom styles on this object.
Expand Down
26 changes: 9 additions & 17 deletions src/textual/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
from .renderables.blank import Blank
from .rlock import RLock
from .strip import Strip
from .walk import walk_depth_first

if TYPE_CHECKING:
from .app import App, ComposeResult
Expand Down Expand Up @@ -807,21 +806,14 @@ def get_widget_by_id(
NoMatches: if no children could be found for this ID.
WrongType: if the wrong type was found.
"""
# We use Widget as a filter_type so that the inferred type of child is Widget.
for child in walk_depth_first(self, filter_type=Widget):
try:
if expect_type is None:
return child.get_child_by_id(id)
else:
return child.get_child_by_id(id, expect_type=expect_type)
except NoMatches:
pass
except WrongType as exc:
raise WrongType(
f"Descendant with id={id!r} is wrong type; expected {expect_type},"
f" got {type(child)}"
) from exc
raise NoMatches(f"No descendant found with id={id!r}")

widget = self.query_one(f"#{id}")
if expect_type is not None and not isinstance(widget, expect_type):
raise WrongType(
f"Descendant with id={id!r} is wrong type; expected {expect_type},"
f" got {type(widget)}"
)
return widget

def get_child_by_type(self, expect_type: type[ExpectType]) -> ExpectType:
"""Get the first immediate child of a given type.
Expand Down Expand Up @@ -958,7 +950,7 @@ def _find_mount_point(self, spot: int | str | "Widget") -> tuple["Widget", int]:
# can be passed to query_one. So let's use that to get a widget to
# work on.
if isinstance(spot, str):
spot = self.query_one(spot, Widget)
spot = self.query_exactly_one(spot, Widget)

# At this point we should have a widget, either because we got given
# one, or because we pulled one out of the query. First off, does it
Expand Down
2 changes: 1 addition & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class App(Widget):
assert app.query_one("#widget1") == widget1
assert app.query_one("#widget1", Widget) == widget1
with pytest.raises(TooManyMatches):
_ = app.query_one(Widget)
_ = app.query_exactly_one(Widget)

assert app.query("Widget.float")[0] == sidebar
assert app.query("Widget.float")[0:2] == [sidebar, tooltip]
Expand Down
Loading