diff --git a/src/libraries/System.Collections/ref/System.Collections.cs b/src/libraries/System.Collections/ref/System.Collections.cs index b0272046b3841..c12b04c3a3cf0 100644 --- a/src/libraries/System.Collections/ref/System.Collections.cs +++ b/src/libraries/System.Collections/ref/System.Collections.cs @@ -126,6 +126,7 @@ public void EnqueueRange(System.Collections.Generic.IEnumerable<(TElement Elemen public void EnqueueRange(System.Collections.Generic.IEnumerable elements, TPriority priority) { } public int EnsureCapacity(int capacity) { throw null; } public TElement Peek() { throw null; } + public bool Remove(TElement element, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TElement removedElement, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TPriority priority, System.Collections.Generic.IEqualityComparer? equalityComparer = null) { throw null; } public void TrimExcess() { } public bool TryDequeue([System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TElement element, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TPriority priority) { throw null; } public bool TryPeek([System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TElement element, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TPriority priority) { throw null; } diff --git a/src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs b/src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs index e8d8221641cd2..edc1327b446ca 100644 --- a/src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs +++ b/src/libraries/System.Collections/src/System/Collections/Generic/PriorityQueue.cs @@ -502,6 +502,59 @@ public void EnqueueRange(IEnumerable elements, TPriority priority) } } + /// + /// Removes the first occurrence that equals the specified parameter. + /// + /// The element to try to remove. + /// The actual element that got removed from the queue. + /// The priority value associated with the removed element. + /// The equality comparer governing element equality. + /// if matching entry was found and removed, otherwise. + /// + /// The method performs a linear-time scan of every element in the heap, removing the first value found to match the parameter. + /// In case of duplicate entries, what entry does get removed is non-deterministic and does not take priority into account. + /// + /// If no is specified, will be used instead. + /// + public bool Remove( + TElement element, + [MaybeNullWhen(false)] out TElement removedElement, + [MaybeNullWhen(false)] out TPriority priority, + IEqualityComparer? equalityComparer = null) + { + int index = FindIndex(element, equalityComparer); + if (index < 0) + { + removedElement = default; + priority = default; + return false; + } + + (TElement Element, TPriority Priority)[] nodes = _nodes; + (removedElement, priority) = nodes[index]; + int newSize = --_size; + + if (index < newSize) + { + // We're removing an element from the middle of the heap. + // Pop the last element in the collection and sift downward from the removed index. + (TElement Element, TPriority Priority) lastNode = nodes[newSize]; + + if (_comparer == null) + { + MoveDownDefaultComparer(lastNode, index); + } + else + { + MoveDownCustomComparer(lastNode, index); + } + } + + nodes[newSize] = default; + _version++; + return true; + } + /// /// Removes all items from the . /// @@ -809,6 +862,41 @@ private void MoveDownCustomComparer((TElement Element, TPriority Priority) node, nodes[nodeIndex] = node; } + /// + /// Scans the heap for the first index containing an element equal to the specified parameter. + /// + private int FindIndex(TElement element, IEqualityComparer? equalityComparer) + { + equalityComparer ??= EqualityComparer.Default; + ReadOnlySpan<(TElement Element, TPriority Priority)> nodes = _nodes.AsSpan(0, _size); + + // Currently the JIT doesn't optimize direct EqualityComparer.Default.Equals + // calls for reference types, so we want to cache the comparer instance instead. + // TODO https://github.com/dotnet/runtime/issues/10050: Update if this changes in the future. + if (typeof(TElement).IsValueType && equalityComparer == EqualityComparer.Default) + { + for (int i = 0; i < nodes.Length; i++) + { + if (EqualityComparer.Default.Equals(element, nodes[i].Element)) + { + return i; + } + } + } + else + { + for (int i = 0; i < nodes.Length; i++) + { + if (equalityComparer.Equals(element, nodes[i].Element)) + { + return i; + } + } + } + + return -1; + } + /// /// Initializes the custom comparer to be used internally by the heap. /// diff --git a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Generic.Tests.cs b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Generic.Tests.cs index c50dfa977d045..0f50f9183f5cb 100644 --- a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Generic.Tests.cs +++ b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Generic.Tests.cs @@ -93,7 +93,7 @@ public void PriorityQueue_EnumerableConstructor_ShouldContainAllElements(int cou #endregion - #region Enqueue, Dequeue, Peek, EnqueueDequeue, DequeueEnqueue + #region Enqueue, Dequeue, Peek, EnqueueDequeue, DequeueEnqueue, Remove [Theory] [MemberData(nameof(ValidCollectionSizes))] @@ -246,6 +246,35 @@ public void PriorityQueue_DequeueEnqueue(int count) AssertExtensions.CollectionEqual(expectedItems, queue.UnorderedItems, EqualityComparer<(TElement, TPriority)>.Default); } + [Theory] + [MemberData(nameof(ValidCollectionSizes))] + public void PriorityQueue_Remove_AllElements(int count) + { + bool result; + TElement removedElement; + TPriority removedPriority; + + PriorityQueue queue = CreatePriorityQueue(count, count, out List<(TElement element, TPriority priority)> generatedItems); + + for (int i = count - 1; i >= 0; i--) + { + (TElement element, TPriority priority) = generatedItems[i]; + + result = queue.Remove(element, out removedElement, out removedPriority); + + Assert.True(result); + Assert.Equal(element, removedElement); + Assert.Equal(priority, removedPriority); + Assert.Equal(i, queue.Count); + } + + result = queue.Remove(default, out removedElement, out removedPriority); + + Assert.False(result); + Assert.Equal(default, removedElement); + Assert.Equal(default, removedPriority); + } + #endregion #region Clear diff --git a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.Dijkstra.cs b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.Dijkstra.cs new file mode 100644 index 0000000000000..08155e493ee66 --- /dev/null +++ b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.Dijkstra.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Linq; +using Xunit; +using NodeId = int; +using Distance = int; + +namespace System.Collections.Tests +{ + public partial class PriorityQueue_NonGeneric_Tests + { + public record struct Graph(Edge[][] nodes); + public record struct Edge(NodeId neighbor, Distance weight); + + [Fact] + public static void PriorityQueue_DijkstraSmokeTest() + { + var graph = new Graph([ + [new Edge(1, 7), new Edge(2, 9), new Edge(5, 14)], + [new Edge(0, 7), new Edge(2, 10), new Edge(3, 15)], + [new Edge(0, 9), new Edge(1, 10), new Edge(3, 11), new Edge(5, 2)], + [new Edge(1, 15), new Edge(2, 11), new Edge(4, 6)], + [new Edge(3, 6), new Edge(5, 9)], + [new Edge(0, 14), new Edge(2, 2), new Edge(4, 9)], + ]); + + NodeId startNode = 0; + + (NodeId node, Distance distance)[] expectedDistances = + [ + (0, 0), + (1, 7), + (2, 9), + (3, 20), + (4, 20), + (5, 11), + ]; + + (NodeId node, Distance distance)[] actualDistances = RunDijkstra(graph, startNode); + + Assert.Equal(expectedDistances, actualDistances); + } + + public static (NodeId node, Distance distance)[] RunDijkstra(Graph graph, NodeId startNode) + { + Distance[] distances = Enumerable.Repeat(int.MaxValue, graph.nodes.Length).ToArray(); + var queue = new PriorityQueue(); + + distances[startNode] = 0; + queue.Enqueue(startNode, 0); + + do + { + NodeId nodeId = queue.Dequeue(); + Distance nodeDistance = distances[nodeId]; + + foreach (Edge edge in graph.nodes[nodeId]) + { + Distance distance = distances[edge.neighbor]; + Distance newDistance = nodeDistance + edge.weight; + if (newDistance < distance) + { + distances[edge.neighbor] = newDistance; + // Simulate priority update by attempting to remove the entry + // before re-inserting it with the new distance. + queue.Remove(edge.neighbor, out _, out _); + queue.Enqueue(edge.neighbor, newDistance); + } + } + } + while (queue.Count > 0); + + return distances.Select((distance, nodeId) => (nodeId, distance)).ToArray(); + } + } +} diff --git a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs index 0bd6a70a8f2a5..03419ab9f6af2 100644 --- a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs +++ b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs @@ -8,7 +8,7 @@ namespace System.Collections.Tests { - public class PriorityQueue_NonGeneric_Tests : TestBase + public partial class PriorityQueue_NonGeneric_Tests : TestBase { protected PriorityQueue CreateSmallPriorityQueue(out HashSet<(string, int)> items) { @@ -167,6 +167,55 @@ public void PriorityQueue_Generic_EnqueueRange_Null() Assert.Equal("not null", queue.Dequeue()); } + [Fact] + public void PriorityQueue_Generic_Remove_MatchingElement() + { + PriorityQueue queue = new PriorityQueue(); + queue.EnqueueRange([("value0", 0), ("value1", 1), ("value2", 2)]); + + Assert.True(queue.Remove("value1", out string removedElement, out int removedPriority)); + Assert.Equal("value1", removedElement); + Assert.Equal(1, removedPriority); + Assert.Equal(2, queue.Count); + } + + [Fact] + public void PriorityQueue_Generic_Remove_MismatchElement() + { + PriorityQueue queue = new PriorityQueue(); + queue.EnqueueRange([("value0", 0), ("value1", 1), ("value2", 2)]); + + Assert.False(queue.Remove("value4", out string removedElement, out int removedPriority)); + Assert.Null(removedElement); + Assert.Equal(0, removedPriority); + Assert.Equal(3, queue.Count); + } + + [Fact] + public void PriorityQueue_Generic_Remove_DuplicateElement() + { + PriorityQueue queue = new PriorityQueue(); + queue.EnqueueRange([("value0", 0), ("value1", 1), ("value0", 2)]); + + Assert.True(queue.Remove("value0", out string removedElement, out int removedPriority)); + Assert.Equal("value0", removedElement); + Assert.True(removedPriority is 0 or 2); + Assert.Equal(2, queue.Count); + } + + [Fact] + public void PriorityQueue_Generic_Remove_CustomEqualityComparer() + { + PriorityQueue queue = new PriorityQueue(); + queue.EnqueueRange([("value0", 0), ("value1", 1), ("value2", 2)]); + EqualityComparer equalityComparer = EqualityComparer.Create((left, right) => left[^1] == right[^1]); + + Assert.True(queue.Remove("someOtherValue1", out string removedElement, out int removedPriority, equalityComparer)); + Assert.Equal("value1", removedElement); + Assert.Equal(1, removedPriority); + Assert.Equal(2, queue.Count); + } + [Fact] public void PriorityQueue_Constructor_int_Negative_ThrowsArgumentOutOfRangeException() { @@ -207,6 +256,16 @@ public void PriorityQueue_EmptyCollection_Peek_ShouldReturnFalse() Assert.Throws(() => queue.Peek()); } + [Fact] + public void PriorityQueue_EmptyCollection_Remove_ShouldReturnFalse() + { + var queue = new PriorityQueue(); + + Assert.False(queue.Remove(element: "element", out string removedElement, out string removedPriority)); + Assert.Null(removedElement); + Assert.Null(removedPriority); + } + #region EnsureCapacity, TrimExcess [Fact] diff --git a/src/libraries/System.Collections/tests/System.Collections.Tests.csproj b/src/libraries/System.Collections/tests/System.Collections.Tests.csproj index 3d64df4f13ee5..1f45aa68f98f6 100644 --- a/src/libraries/System.Collections/tests/System.Collections.Tests.csproj +++ b/src/libraries/System.Collections/tests/System.Collections.Tests.csproj @@ -1,4 +1,4 @@ - + $(NetCoreAppCurrent) true @@ -106,6 +106,7 @@ +