diff --git a/distributed/collections.py b/distributed/collections.py index 0eefa9304f6..4aef7d555e9 100644 --- a/distributed/collections.py +++ b/distributed/collections.py @@ -36,13 +36,26 @@ class HeapSet(MutableSet[T]): arbitrary key function. Ties are broken by oldest first. Values must be compatible with :mod:`weakref`. + + Parameters + ---------- + key: Callable + A function that takes a single element of the collection as a parameter and + returns a sorting key. The key does not need to be hashable and does not need to + support :mod:`weakref`. + + Note + ---- + The key returned for each element should not change over time. If it does, the + position in the heap won't change, even if the element is re-added, and it *may* not + change even if it's discarded and then re-added later. """ __slots__ = ("key", "_data", "_heap", "_inc", "_sorted") key: Callable[[T], Any] _data: set[T] - _inc: int _heap: list[tuple[Any, int, weakref.ref[T]]] + _inc: int _sorted: bool def __init__(self, *, key: Callable[[T], Any]): @@ -106,7 +119,9 @@ def peek(self) -> T: self._sorted = False def peekn(self, n: int) -> Iterator[T]: - "Iterator over the N smallest elements. This is O(1) for n == 1, O(n*logn) otherwise." + """Iterate over the n smallest elements without removing them. + This is O(1) for n == 1; O(n*logn) otherwise. 
+ """ if n <= 0: return # empty iterator if n == 1: @@ -173,10 +188,12 @@ def sorted(self) -> Iterator[T]: if not self._sorted: self._heap.sort() # A sorted list maintains the heap invariant self._sorted = True + seen = set() for _, _, vref in self._heap: value = vref() - if value in self._data: + if value in self._data and value not in seen: yield value + seen.add(value) def clear(self) -> None: self._data.clear() diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 0e450f479be..066cf147a33 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -325,3 +325,17 @@ def test_heapset_pickle(): assert len(heap2._heap) < len(heap._heap) while heap: assert heap.pop() == heap2.pop() + + +def test_heapset_sort_duplicate(): + """See https://github.com/dask/distributed/issues/6951""" + heap = HeapSet(key=operator.attrgetter("i")) + c1 = C("x", 1) + c2 = C("2", 2) + + heap.add(c1) + heap.add(c2) + heap.discard(c1) + heap.add(c1) + + assert list(heap.sorted()) == [c1, c2]