Skip to content

Commit

Permalink
modernize type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
bfontaine committed Oct 7, 2024
1 parent b31e53d commit 7bf70fe
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
60 changes: 30 additions & 30 deletions clj/seqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import collections.abc as collections_abc
import itertools
import random
from typing import Iterable, TypeVar, Any, Callable, Iterator, Union, Tuple, Dict, Optional, List, Set, cast, Deque
from typing import Iterable, TypeVar, Any, Callable, Iterator, Union, cast, Deque

import clj

Expand Down Expand Up @@ -224,7 +224,7 @@ def butlast(coll: Iterable[T]) -> Iterable[T]:
Return a generator of all but the last item in ``coll``, in linear time.
"""
first_ = True
last_e: Optional[T] = None
last_e: T | None = None
for e in coll:
if first_:
last_e = e
Expand Down Expand Up @@ -303,7 +303,7 @@ def _iter(coll, n=0):
return coll[n:]


def split_at(n: int, coll: Iterable[T]) -> Tuple[Iterable[T], Iterable[T]]:
def split_at(n: int, coll: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
"""
Returns a tuple of ``(take(n, coll), drop(n coll))``.
"""
Expand All @@ -325,7 +325,7 @@ def split_at(n: int, coll: Iterable[T]) -> Tuple[Iterable[T], Iterable[T]]:
return taken, _iter(coll, n)


def split_with(pred: Callable[[T], Any], coll: Iterable[T]) -> Tuple[Iterable[T], Iterable[T]]:
def split_with(pred: Callable[[T], Any], coll: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
"""
Returns a tuple of ``(take_while(pred, coll), drop_while(pred coll))``.
"""
Expand All @@ -348,7 +348,7 @@ def dropped_while():
return taken, dropped_while()


def replace(smap: Dict[T, T2], coll: Iterable[T]) -> Iterable[Union[T, T2]]:
def replace(smap: dict[T, T2], coll: Iterable[T]) -> Iterable[Union[T, T2]]:
"""
Given a map of replacement pairs and a list/collection, yield a sequence
where any element = a key in ``smap`` replaced with the corresponding val
Expand Down Expand Up @@ -396,7 +396,7 @@ def map_indexed(f: Callable[[int, T], T2], coll: Iterable[T]) -> Iterable[T2]:
return map(lambda pair: f(pair[0], pair[1]), enumerate(coll))


def _first(coll: Iterable[T]) -> Tuple[Optional[T], bool]:
def _first(coll: Iterable[T]) -> tuple[T | None, bool]:
"""
Like first(coll), but return a tuple of ``(first, is_empty)`` where `first` is either the first
element of the collection or ``None`` and ``is_empty`` is a boolean that is ``True`` if the collection
Expand All @@ -409,18 +409,18 @@ def _first(coll: Iterable[T]) -> Tuple[Optional[T], bool]:
first_value: Union[T, object] = next(_make_gen(take(1, coll)), _flag)
if first_value is _flag:
return None, True
return cast(Optional[T], first_value), False
return cast(T | None, first_value), False


def first(coll: Iterable[T]) -> Optional[T]:
def first(coll: Iterable[T]) -> T | None:
"""
Returns the first item in the collection. If ``coll`` is empty, returns ``None``.
"""
first_value: Optional[T] = _first(coll)[0]
first_value: T | None = _first(coll)[0]
return first_value


def ffirst(x: Iterable[Iterable[T]]) -> Optional[T]:
def ffirst(x: Iterable[Iterable[T]]) -> T | None:
"""
Same as ``first(first(x))``
"""
Expand All @@ -440,7 +440,7 @@ def nfirst(x: Iterable[Iterable[T]]) -> Iterable[T]:
return rest(f)


def second(coll: Iterable[T]) -> Optional[T]:
def second(coll: Iterable[T]) -> T | None:
"""
Same as ``first(rest(coll))``.
"""
Expand Down Expand Up @@ -473,7 +473,7 @@ def nth(coll: Iterable[T], n: int, not_found: Any = _nil) -> Any:
return not_found


def last(coll: Iterable[T]) -> Optional[T]:
def last(coll: Iterable[T]) -> T | None:
"""
Return the last item in ``coll``, in linear time. Return ``None`` if ``coll`` is empty.
"""
Expand All @@ -483,14 +483,14 @@ def last(coll: Iterable[T]) -> Optional[T]:
return e


def zipmap(keys: Iterable[T], vals: Iterable[T2]) -> Dict[T, T2]:
def zipmap(keys: Iterable[T], vals: Iterable[T2]) -> dict[T, T2]:
"""
Return a ``dict`` with the keys mapped to the corresponding ``vals``.
"""
return dict(zip(keys, vals))


def group_by(f: Callable[[T], T2], coll: Iterable[T]) -> Dict[T2, List[T]]:
def group_by(f: Callable[[T], T2], coll: Iterable[T]) -> dict[T2, list[T]]:
"""
Returns a ``dict`` of the elements of ``coll`` keyed by the result of ``f``
on each element. The value at each key will be a list of the corresponding
Expand All @@ -503,14 +503,14 @@ def group_by(f: Callable[[T], T2], coll: Iterable[T]) -> Dict[T2, List[T]]:
return dict(groups)


def _make_pred(pred: Union[Callable[[T], T2], Set[T]]) -> Callable[[T], Union[T2, bool]]:
def _make_pred(pred: Union[Callable[[T], T2], set[T]]) -> Callable[[T], Union[T2, bool]]:
if isinstance(pred, set):
return lambda x: x in cast(Set[T], pred)
return lambda x: x in cast(set[T], pred)

return pred


def some(pred: Union[Callable[[T], Any], Set[T]], coll: Iterable[T]) -> Optional[T]:
def some(pred: Union[Callable[[T], Any], set[T]], coll: Iterable[T]) -> T | None:
"""
Returns the first logical true value of ``pred(x)`` for any ``x`` in coll,
else ``None``.
Expand All @@ -536,7 +536,7 @@ def is_seq(x):
return isinstance(x, collections_abc.Sequence)


def every(pred: Union[Callable[[T], Any], Set[T]], coll: Iterable[T]) -> bool:
def every(pred: Union[Callable[[T], Any], set[T]], coll: Iterable[T]) -> bool:
"""
Returns ``True`` if ``pred(x)`` is logical true for every ``x`` in
``coll``, else i``False``.
Expand All @@ -550,15 +550,15 @@ def every(pred: Union[Callable[[T], Any], Set[T]], coll: Iterable[T]) -> bool:
return True


def not_every(pred: Union[Callable[[T], Any], Set[T]], coll: Iterable[T]) -> bool:
def not_every(pred: Union[Callable[[T], Any], set[T]], coll: Iterable[T]) -> bool:
"""
Returns ``False`` if ``pred(x)`` is logical true for every ``x`` in
``coll``, else ``True``.
"""
return not every(pred, coll)


def not_any(pred: Union[Callable[[T], Any], Set[T]], coll: Iterable[T]) -> bool:
def not_any(pred: Union[Callable[[T], Any], set[T]], coll: Iterable[T]) -> bool:
"""
Return ``False`` if ``pred(x)`` is logical true for any ``x`` in ``coll``,
else ``True``.
Expand All @@ -581,7 +581,7 @@ def dorun(coll: Iterable) -> None:
return None


def repeatedly(f: Union[Callable[[], T2], int], n: Optional[Union[int, Callable[[], T2]]] = None) \
def repeatedly(f: Union[Callable[[], T2], int], n: Union[int, Callable[[], T2]] | None = None) \
-> Iterable[T2]:
"""
Takes a function of no args, presumably with side effects, and returns an
Expand Down Expand Up @@ -611,7 +611,7 @@ def iterate(f: Callable[[Any], Any], x) -> Iterable:
x = f(x)


def repeat(x: T, n: Optional[int] = None) -> Iterable[T]:
def repeat(x: T, n: int | None = None) -> Iterable[T]:
"""
Returns a generator that indefinitely yields ``x`` (or ``n`` times if ``n`` is supplied).
Expand Down Expand Up @@ -688,7 +688,7 @@ def dedupe(coll: Iterable[T]) -> Iterable[T]:
prev = e


def empty(coll: T) -> Optional[T]:
def empty(coll: T) -> T | None:
"""
Returns an empty collection of the same type as ``coll``, or ``None``.
"""
Expand All @@ -713,8 +713,8 @@ def count(coll: Iterable) -> int:
return n


def partition(coll: Iterable[T], n: int, step: Optional[int] = None, pad: Optional[Iterable[T2]] = None) \
-> Iterator[List[Union[T, T2]]]:
def partition(coll: Iterable[T], n: int, step: int | None = None, pad: Iterable[T2] | None = None) \
-> Iterator[list[Union[T, T2]]]:
"""
Returns a generator of lists of ``n`` items each, at offsets ``step`` apart. If ``step`` is not supplied, defaults
to ``n``, i.e. the partitions do not overlap. If a ``pad`` collection is supplied, use its elements as necessary to
Expand All @@ -733,7 +733,7 @@ def partition(coll: Iterable[T], n: int, step: Optional[int] = None, pad: Option
# TODO
raise NotImplementedError("Step != n is not supported for now.")

current_partition: List[Union[T, T2]] = []
current_partition: list[Union[T, T2]] = []
partition_index = 0
partition_end = n

Expand All @@ -755,8 +755,8 @@ def partition(coll: Iterable[T], n: int, step: Optional[int] = None, pad: Option
yield current_partition


def partition_by(f: Callable[[T], Any], coll: Iterable[T]) -> Iterable[List[T]]:
current: List[T] = []
def partition_by(f: Callable[[T], Any], coll: Iterable[T]) -> Iterable[list[T]]:
current: list[T] = []
current_value = None
for element in coll:
if not current:
Expand All @@ -777,7 +777,7 @@ def partition_by(f: Callable[[T], Any], coll: Iterable[T]) -> Iterable[List[T]]:
yield current


def seq_gen(coll: Iterable[T]) -> Optional[Iterable[T]]:
def seq_gen(coll: Iterable[T]) -> Iterable[T] | None:
"""
Like Clojure’s ``seq``, but return a lazy iterable that’s equivalent to ``coll`` if not empty.
Expand All @@ -789,5 +789,5 @@ def seq_gen(coll: Iterable[T]) -> Optional[Iterable[T]]:
"""
first_element, _is_empty = _first(coll)
if _is_empty:
return
return None
return clj.concat([first_element], _iter(coll, 1))
2 changes: 1 addition & 1 deletion tests/test_seqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_flatten_infinite_generators():


def test_flatten_deep_list():
deep_list = ["foo"]
deep_list: list[Any] = ["foo"]
for _ in range(200):
deep_list = [[[[[deep_list]]]]]

Expand Down

0 comments on commit 7bf70fe

Please sign in to comment.