Skip to content

Commit

Permalink
Merge pull request #25 from kbsriram/add-types-3
Browse files Browse the repository at this point in the history
Update type annotations for itertools extras.
  • Loading branch information
FoamyGuy authored Apr 29, 2024
2 parents 750de7a + 3bd2dd9 commit a77ae43
Show file tree
Hide file tree
Showing 3 changed files with 356 additions and 25 deletions.
93 changes: 68 additions & 25 deletions adafruit_itertools/adafruit_itertools_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,54 @@

import adafruit_itertools as it

try:
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import TypeAlias

_T = TypeVar("_T")
_N: TypeAlias = Union[int, float, complex]
_Predicate: TypeAlias = Callable[[_T], bool]
except ImportError:
pass


__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Itertools.git"


def all_equal(iterable):
def all_equal(iterable: Iterable[Any]) -> bool:
"""Returns True if all the elements are equal to each other.
:param iterable: source of values
"""
g = it.groupby(iterable)
next(g) # should succeed, value isn't relevant
try:
next(g) # should fail: only 1 group
next(g) # value isn't relevant
except StopIteration:
# Empty iterable, return True to match cpython behavior.
return True
try:
next(g)
# more than one group, so we have different elements.
return False
except StopIteration:
# Only one group - all elements must be equal.
return True


def dotproduct(vec1, vec2):
def dotproduct(vec1: Iterable[_N], vec2: Iterable[_N]) -> _N:
"""Compute the dot product of two vectors.
:param vec1: the first vector
Expand All @@ -71,7 +99,11 @@ def dotproduct(vec1, vec2):
return sum(map(lambda x, y: x * y, vec1, vec2))


def first_true(iterable, default=False, pred=None):
def first_true(
iterable: Iterable[_T],
default: Union[bool, _T] = False,
pred: Optional[_Predicate[_T]] = None,
) -> Union[bool, _T]:
"""Returns the first true value in the iterable.
If no true value is found, returns *default*
Expand All @@ -94,7 +126,7 @@ def first_true(iterable, default=False, pred=None):
return default


def flatten(iterable_of_iterables):
def flatten(iterable_of_iterables: Iterable[Iterable[_T]]) -> Iterator[_T]:
"""Flatten one level of nesting.
:param iterable_of_iterables: a sequence of iterables to flatten
Expand All @@ -104,7 +136,9 @@ def flatten(iterable_of_iterables):
return it.chain_from_iterable(iterable_of_iterables)


def grouper(iterable, n, fillvalue=None):
def grouper(
iterable: Iterable[_T], n: int, fillvalue: Optional[_T] = None
) -> Iterator[Tuple[_T, ...]]:
"""Collect data into fixed-length chunks or blocks.
:param iterable: source of values
Expand All @@ -118,7 +152,7 @@ def grouper(iterable, n, fillvalue=None):
return it.zip_longest(*args, fillvalue=fillvalue)


def iter_except(func, exception):
def iter_except(func: Callable[[], _T], exception: Type[BaseException]) -> Iterator[_T]:
"""Call a function repeatedly, yielding the results, until exception is raised.
Converts a call-until-exception interface to an iterator interface.
Expand All @@ -143,7 +177,7 @@ def iter_except(func, exception):
pass


def ncycles(iterable, n):
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]:
"""Returns the sequence elements a number of times.
:param iterable: the source of values
Expand All @@ -153,7 +187,7 @@ def ncycles(iterable, n):
return it.chain_from_iterable(it.repeat(tuple(iterable), n))


def nth(iterable, n, default=None):
def nth(iterable: Iterable[_T], n: int, default: Optional[_T] = None) -> Optional[_T]:
"""Returns the nth item or a default value.
:param iterable: the source of values
Expand All @@ -166,7 +200,7 @@ def nth(iterable, n, default=None):
return default


def padnone(iterable):
def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]:
"""Returns the sequence elements and then returns None indefinitely.
Useful for emulating the behavior of the built-in map() function.
Expand All @@ -177,13 +211,17 @@ def padnone(iterable):
return it.chain(iterable, it.repeat(None))


def pairwise(iterable):
"""Pair up valuesin the iterable.
def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]:
"""Return successive overlapping pairs from the iterable.
The number of tuples from the output will be one fewer than the
number of values in the input. It will be empty if the input has
fewer than two values.
:param iterable: source of values
"""
# pairwise(range(11)) -> (1, 2), (3, 4), (5, 6), (7, 8), (9, 10)
# pairwise(range(5)) -> (0, 1), (1, 2), (2, 3), (3, 4)
a, b = it.tee(iterable)
try:
next(b)
Expand All @@ -192,7 +230,9 @@ def pairwise(iterable):
return zip(a, b)


def partition(pred, iterable):
def partition(
pred: _Predicate[_T], iterable: Iterable[_T]
) -> Tuple[Iterator[_T], Iterator[_T]]:
"""Use a predicate to partition entries into false entries and true entries.
:param pred: the predicate that divides the values
Expand All @@ -204,7 +244,7 @@ def partition(pred, iterable):
return it.filterfalse(pred, t1), filter(pred, t2)


def prepend(value, iterator):
def prepend(value: _T, iterator: Iterable[_T]) -> Iterator[_T]:
"""Prepend a single value in front of an iterator
:param value: the value to prepend
Expand All @@ -215,7 +255,7 @@ def prepend(value, iterator):
return it.chain([value], iterator)


def quantify(iterable, pred=bool):
def quantify(iterable: Iterable[_T], pred: _Predicate[_T] = bool) -> int:
"""Count how many times the predicate is true.
:param iterable: source of values
Expand All @@ -227,7 +267,9 @@ def quantify(iterable, pred=bool):
return sum(map(pred, iterable))


def repeatfunc(func, times=None, *args):
def repeatfunc(
func: Callable[..., _T], times: Optional[int] = None, *args: Any
) -> Iterator[_T]:
"""Repeat calls to func with specified arguments.
Example: repeatfunc(random.random)
Expand All @@ -242,7 +284,7 @@ def repeatfunc(func, times=None, *args):
return it.starmap(func, it.repeat(args, times))


def roundrobin(*iterables):
def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]:
"""Return an iterable created by repeatedly picking value from each
argument in order.
Expand All @@ -263,18 +305,19 @@ def roundrobin(*iterables):
nexts = it.cycle(it.islice(nexts, num_active))


def tabulate(function, start=0):
"""Apply a function to a sequence of consecutive integers.
def tabulate(function: Callable[[int], int], start: int = 0) -> Iterator[int]:
"""Apply a function to a sequence of consecutive numbers.
:param function: the function of one integer argument
:param function: the function of one numeric argument.
:param start: optional value to start at (default is 0)
"""
# take(5, tabulate(lambda x: x * x))) -> 0 1 4 9 16
return map(function, it.count(start))
counter: Iterator[int] = it.count(start) # type: ignore[assignment]
return map(function, counter)


def tail(n, iterable):
def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]:
"""Return an iterator over the last n items
:param n: how many values to return
Expand All @@ -294,7 +337,7 @@ def tail(n, iterable):
return iter(buf)


def take(n, iterable):
def take(n: int, iterable: Iterable[_T]) -> List[_T]:
"""Return first n items of the iterable as a list
:param n: how many values to take
Expand Down
3 changes: 3 additions & 0 deletions optional_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: 2022 Alec Delaney, for Adafruit Industries
#
# SPDX-License-Identifier: Unlicense

# For comparison when running tests
more-itertools
Loading

0 comments on commit a77ae43

Please sign in to comment.