From 72aeb80054752821e9d16837a1b97e0beb27aa9f Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Sat, 10 May 2014 09:07:30 -0400 Subject: [PATCH] Faster unique, isdistinct, merge_sorted, and sliding_window. The `key` keyword argument to `unique` was changed from `identity` to `None`. This better matches API elsewhere, and lets us remove `identity` from being redefined in `itertoolz`, which always seemed a little weird. Most of the speed improvements come from avoiding attribute resolution in frequently run code. Attribute resolution (i.e., the "dot" operator) is probably more costly than one would expect. Fortunately, there weren't many places to apply this optimization, so impact on code readability was minimal. `unique` employs another optimization: branching by `key is None` outside the loop (thus requiring two loops). While this violates the DRY principle (and, hence, I would prefer not to do it in general), this is only a few lines of code that remain side-by-side, and the performance increase is worth it. `merge_sorted` is now optimized when only a single iterable remains. This makes it *so* much faster while in this condition. --- toolz/itertoolz.py | 44 ++++++++++++++++++++++------------- toolz/tests/test_itertoolz.py | 7 +++++- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index c187193f..efbdd57e 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -13,9 +13,6 @@ 'sliding_window', 'partition', 'partition_all', 'count', 'pluck') -identity = lambda x: x - - def remove(predicate, seq): """ Return those items of sequence for which predicate(item) is False @@ -120,7 +117,9 @@ def _merge_sorted_key(seqs, key): heapq.heapify(pq) # Repeatedly yield and then repopulate from the same iterator - while True: + heapreplace = heapq.heapreplace + heappop = heapq.heappop + while len(pq) > 1: try: while True: # raises IndexError when pq is empty @@ -129,11 +128,15 @@ def _merge_sorted_key(seqs, key): item = next(it) # raises StopIteration when exhausted s[0] = key(item) s[2] = item - heapq.heapreplace(pq, s) # restore heap condition + heapreplace(pq, s) # restore heap condition except StopIteration: - heapq.heappop(pq) # remove empty iterator - except IndexError: - return + heappop(pq) # remove empty iterator + if pq: + # Much faster when only a single iterable remains + _, itnum, item, it = pq[0] + yield item + for item in it: + yield item def interleave(seqs, pass_exceptions=()): @@ -161,7 +164,7 @@ def interleave(seqs, pass_exceptions=()): iters = newiters -def unique(seq, key=identity): +def unique(seq, key=None): """ Return only unique elements of a sequence >>> tuple(unique((1, 2, 3))) @@ -175,11 +178,18 @@ def unique(seq, key=identity): ('cat', 'mouse') """ seen = set() - for item in seq: - tag = key(item) - if tag not in seen: - seen.add(tag) - yield item + seen_add = seen.add + if key is None: + for item in seq: + if item not in seen: + seen_add(item) + yield item + else: # calculate key + for item in seq: + val = key(item) + if val not in seen: + seen_add(val) + yield item def isiterable(x): @@ -214,10 +224,11 @@ def isdistinct(seq): """ if iter(seq) is seq: seen = set() + seen_add = seen.add for item in seq: if item in seen: return False - seen.add(item) + seen_add(item) return True else: return len(seq) == len(set(seq)) @@ -528,9 +539,10 @@ def sliding_window(n, seq): d = collections.deque(itertools.islice(it, n), n) if len(d) != n: raise StopIteration() + d_append = d.append for item in it: yield tuple(d) - d.append(item) + d_append(item) yield tuple(d) diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index e8cedbc3..d4045e57 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -3,7 +3,7 @@ from functools import partial from toolz.itertoolz import (remove, groupby, merge_sorted, concat, concatv, interleave, unique, - identity, isiterable, + isiterable, mapcat, isdistinct, first, second, nth, take, drop, interpose, get, rest, last, cons, frequencies, @@ -14,6 +14,10 @@ from operator import add, mul +def identity(x): + return x + + def iseven(x): return x % 2 == 0 @@ -54,6 +58,7 @@ def test_merge_sorted(): assert ''.join(merge_sorted('abc', 'abc', 'abc', key=ord)) == 'aaabbbccc' assert ''.join(merge_sorted('cba', 'cba', 'cba', key=lambda x: -ord(x))) == 'cccbbbaaa' + assert list(merge_sorted([1], [2, 3, 4], key=identity)) == [1, 2, 3, 4] def test_interleave():