diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index c187193f..efbdd57e 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -13,9 +13,6 @@ 'sliding_window', 'partition', 'partition_all', 'count', 'pluck') -identity = lambda x: x - - def remove(predicate, seq): """ Return those items of sequence for which predicate(item) is False @@ -120,7 +117,9 @@ def _merge_sorted_key(seqs, key): heapq.heapify(pq) # Repeatedly yield and then repopulate from the same iterator - while True: + heapreplace = heapq.heapreplace + heappop = heapq.heappop + while len(pq) > 1: try: while True: # raises IndexError when pq is empty @@ -129,11 +128,15 @@ def _merge_sorted_key(seqs, key): item = next(it) # raises StopIteration when exhausted s[0] = key(item) s[2] = item - heapq.heapreplace(pq, s) # restore heap condition + heapreplace(pq, s) # restore heap condition except StopIteration: - heapq.heappop(pq) # remove empty iterator - except IndexError: - return + heappop(pq) # remove empty iterator + if pq: + # Much faster when only a single iterable remains + _, itnum, item, it = pq[0] + yield item + for item in it: + yield item def interleave(seqs, pass_exceptions=()): @@ -161,7 +164,7 @@ def interleave(seqs, pass_exceptions=()): iters = newiters -def unique(seq, key=identity): +def unique(seq, key=None): """ Return only unique elements of a sequence >>> tuple(unique((1, 2, 3))) @@ -175,11 +178,18 @@ def unique(seq, key=identity): ('cat', 'mouse') """ seen = set() - for item in seq: - tag = key(item) - if tag not in seen: - seen.add(tag) - yield item + seen_add = seen.add + if key is None: + for item in seq: + if item not in seen: + seen_add(item) + yield item + else: # calculate key + for item in seq: + val = key(item) + if val not in seen: + seen_add(val) + yield item def isiterable(x): @@ -214,10 +224,11 @@ def isdistinct(seq): """ if iter(seq) is seq: seen = set() + seen_add = seen.add for item in seq: if item in seen: return False - seen.add(item) + seen_add(item) return True else: return len(seq) == len(set(seq)) @@ -528,9 +539,10 @@ def sliding_window(n, seq): d = collections.deque(itertools.islice(it, n), n) if len(d) != n: raise StopIteration() + d_append = d.append for item in it: yield tuple(d) - d.append(item) + d_append(item) yield tuple(d) diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index e8cedbc3..d4045e57 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -3,7 +3,7 @@ from functools import partial from toolz.itertoolz import (remove, groupby, merge_sorted, concat, concatv, interleave, unique, - identity, isiterable, + isiterable, mapcat, isdistinct, first, second, nth, take, drop, interpose, get, rest, last, cons, frequencies, @@ -14,6 +14,10 @@ from operator import add, mul +def identity(x): + return x + + def iseven(x): return x % 2 == 0 @@ -54,6 +58,7 @@ def test_merge_sorted(): assert ''.join(merge_sorted('abc', 'abc', 'abc', key=ord)) == 'aaabbbccc' assert ''.join(merge_sorted('cba', 'cba', 'cba', key=lambda x: -ord(x))) == 'cccbbbaaa' + assert list(merge_sorted([1], [2, 3, 4], key=identity)) == [1, 2, 3, 4] def test_interleave():