Skip to content

Commit

Permalink
Prevent duplicates in HeapSet.sorted() (dask#6952)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored and gjoseph92 committed Oct 31, 2022
1 parent 1bb3c82 commit a7a8f44
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
23 changes: 20 additions & 3 deletions distributed/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,26 @@ class HeapSet(MutableSet[T]):
arbitrary key function. Ties are broken by oldest first.
Values must be compatible with :mod:`weakref`.
Parameters
----------
key: Callable
A function that takes a single element of the collection as a parameter and
returns a sorting key. The key does not need to be hashable and does not need to
support :mod:`weakref`.
Note
----
The key returned for each element should not change over time. If it does, the
position in the heap won't change, even if the element is re-added, and it *may* not
change even if it's discarded and then re-added later.
"""

__slots__ = ("key", "_data", "_heap", "_inc", "_sorted")
key: Callable[[T], Any]
_data: set[T]
_inc: int
_heap: list[tuple[Any, int, weakref.ref[T]]]
_inc: int
_sorted: bool

def __init__(self, *, key: Callable[[T], Any]):
Expand Down Expand Up @@ -106,7 +119,9 @@ def peek(self) -> T:
self._sorted = False

def peekn(self, n: int) -> Iterator[T]:
"Iterator over the N smallest elements. This is O(1) for n == 1, O(n*logn) otherwise."
"""Iterate over the n smallest elements without removing them.
This is O(1) for n == 1; O(n*logn) otherwise.
"""
if n <= 0:
return # empty iterator
if n == 1:
Expand Down Expand Up @@ -173,10 +188,12 @@ def sorted(self) -> Iterator[T]:
if not self._sorted:
self._heap.sort() # A sorted list maintains the heap invariant
self._sorted = True
seen = set()
for _, _, vref in self._heap:
value = vref()
if value in self._data:
if value in self._data and value not in seen:
yield value
seen.add(value)

def clear(self) -> None:
self._data.clear()
Expand Down
14 changes: 14 additions & 0 deletions distributed/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,17 @@ def test_heapset_pickle():
assert len(heap2._heap) < len(heap._heap)
while heap:
assert heap.pop() == heap2.pop()


def test_heapset_sort_duplicate():
"""See https://github.com/dask/distributed/issues/6951"""
heap = HeapSet(key=operator.attrgetter("i"))
c1 = C("x", 1)
c2 = C("2", 2)

heap.add(c1)
heap.add(c2)
heap.discard(c1)
heap.add(c1)

assert list(heap.sorted()) == [c1, c2]

0 comments on commit a7a8f44

Please sign in to comment.