Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(type hints): various improvements to type hints #6862

Closed
wants to merge 8 commits into from
1 change: 1 addition & 0 deletions ibis/backends/dask/aggcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def sliced_agg(s):

class Window(AggregationContext):
__slots__ = ("construct_window",)
construct_window: operator.methodcaller

def __init__(self, kind, *args, **kwargs):
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/impala/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _parse_storage_desc_params(self):
params = self._parse_nested_params(self._storage_param_cleaners)
self.storage["Desc Params"] = params

_storage_param_cleaners = {}
_storage_param_cleaners: dict = {}

def _parse_nested_params(self, cleaners):
import pandas as pd
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/mysql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class _FieldFlags:
NUM = 1 << 15

__slots__ = ("value",)
value: int

def __init__(self, value: int) -> None:
self.value = value
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/pandas/aggcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def window_agg_udf(

class Window(AggregationContext):
__slots__ = ("construct_window",)
construct_window: operator.methodcaller

def __init__(self, kind, *args, **kwargs):
super().__init__(
Expand Down
4 changes: 4 additions & 0 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class Annotation:
"""

__slots__ = ("_pattern", "_default", "_typehint")
_pattern: Pattern | callable | None
_default: Any
_typehint: Any

def __init__(self, pattern=None, default=EMPTY, typehint=EMPTY):
if pattern is None or isinstance(pattern, Pattern):
Expand Down Expand Up @@ -127,6 +130,7 @@ class Argument(Annotation):
"""

__slots__ = ("_kind",)
_kind: int

def __init__(
self,
Expand Down
3 changes: 2 additions & 1 deletion ibis/common/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,9 @@ class Slotted(Base):
"""

__slots__ = ("__precomputed_hash__",)
__precomputed_hash__: int

def __init__(self, **kwargs) -> Self:
def __init__(self, **kwargs) -> None:
for name, value in kwargs.items():
object.__setattr__(self, name, value)
hashvalue = hash(tuple(kwargs.values()))
Expand Down
7 changes: 4 additions & 3 deletions ibis/common/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def memoize(func: Callable) -> Callable:
"""Memoize a function."""
cache = {}
cache: dict = {}

@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -33,6 +33,7 @@ def wrapper(*args, **kwargs):

class WeakCache(MutableMapping):
__slots__ = ("_data",)
_data: dict

def __init__(self):
object.__setattr__(self, "_data", {})
Expand Down Expand Up @@ -94,11 +95,11 @@ def __init__(
key: Callable[[Any], Any],
) -> None:
self.cache = bidict()
self.refs = Counter()
self.refs: Counter = Counter()
self.populate = populate
self.lookup = lookup
self.finalize = finalize
self.names = defaultdict(generate_name)
self.names: defaultdict = defaultdict(generate_name)
self.key = key or (lambda x: x)

def get(self, key, default=None):
Expand Down
3 changes: 3 additions & 0 deletions ibis/common/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ class FrozenDict(Mapping[K, V], Hashable):
"""Immutable dictionary with a precomputed hash value."""

__slots__ = ("__view__", "__precomputed_hash__")
__view__: MappingProxyType
__precomputed_hash__: int

def __init__(self, *args, **kwargs):
dictview = MappingProxyType(dict(*args, **kwargs))
Expand Down Expand Up @@ -215,4 +217,5 @@ def checkpoint(self):
self._iterator, self._checkpoint = tee(self._iterator)


frozendict: FrozenDict
public(frozendict=FrozenDict)
19 changes: 19 additions & 0 deletions ibis/common/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools
import math
from typing import (
TYPE_CHECKING,
Any,
Hashable,
Iterable,
Expand All @@ -16,6 +17,9 @@
from ibis.common.graph import Node
from ibis.util import promote_list

if TYPE_CHECKING:
from collections.abc import Callable

K = TypeVar("K", bound=Hashable)


Expand Down Expand Up @@ -56,6 +60,8 @@ class DisjointSet(Mapping[K, Set[K]]):
"""

__slots__ = ("_parents", "_classes")
_parents: dict
_classes: dict

def __init__(self, data: Iterable[K] | None = None):
self._parents = {}
Expand Down Expand Up @@ -248,6 +254,7 @@ class Slotted:
"""

__slots__ = ("__precomputed_hash__",)
__precomputed_hash__: int

def __init__(self, *args):
for name, value in itertools.zip_longest(self.__slots__, args):
Expand Down Expand Up @@ -281,6 +288,7 @@ class Variable(Slotted):
"""

__slots__ = ("name",)
name: str

def __init__(self, name: str):
if name is None:
Expand Down Expand Up @@ -329,6 +337,9 @@ class Pattern(Slotted):
"""

__slots__ = ("head", "args", "name")
head: type
args: tuple
name: str | None

# TODO(kszucs): consider to raise if the pattern matches none
def __init__(self, head, args, name=None, conditions=None):
Expand Down Expand Up @@ -442,6 +453,7 @@ class DynamicApplier(Slotted):
"""A dynamic applier which calls a function to compute the result."""

__slots__ = ("func",)
func: Callable

def substitute(self, egraph, enode, subst):
kwargs = {k: v for k, v in subst.items() if isinstance(k, str)}
Expand All @@ -455,6 +467,8 @@ class Rewrite(Slotted):
"""A rewrite rule which matches a pattern and applies a pattern or a function."""

__slots__ = ("matcher", "applier")
matcher: Pattern
applier: Callable | Pattern | Variable

def __init__(self, matcher, applier):
if callable(applier):
Expand All @@ -481,6 +495,8 @@ class ENode(Slotted, Node):
"""

__slots__ = ("head", "args")
head: type
args: tuple

def __init__(self, head, args):
# TODO(kszucs): ensure that it is a ground term, this check should be removed
Expand Down Expand Up @@ -529,6 +545,9 @@ def mapper(node, _, **kwargs):

class EGraph:
__slots__ = ("_nodes", "_etables", "_eclasses")
_nodes: dict
_etables: collections.defaultdict
_eclasses: DisjointSet

def __init__(self):
# store the nodes before converting them to enodes, so we can spare the initial
Expand Down
Loading