Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Share the majority of code between the Binary and Flat-Array versions of our interval trees. #73859

Merged
merged 18 commits into from
Jun 6, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using Microsoft.CodeAnalysis.Collections;
using Microsoft.CodeAnalysis.Collections.Internal;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.Shared.Collections;
Expand All @@ -26,13 +23,6 @@ namespace Microsoft.CodeAnalysis.Shared.Collections;

public static readonly FlatArrayIntervalTree<T> Empty = new(new SegmentedArray<Node>(0));

private static readonly ObjectPool<Stack<(int nodeIndex, bool firstTime)>> s_stackPool = new(() => new(), trimOnFree: false);

/// <summary>
/// Keep around a fair number of these as we often use them in parallel algorithms.
/// </summary>
private static readonly ObjectPool<Stack<int>> s_nodeIndexPool = new(() => new(), 128, trimOnFree: false);

/// <summary>
/// The nodes of this interval tree flatted into a single array. The root is as index 0. The left child of any
/// node at index <c>i</c> is at <c>2*i + 1</c> and the right child is at <c>2*i + 2</c>. If a left/right child
Expand All @@ -46,9 +36,7 @@ namespace Microsoft.CodeAnalysis.Shared.Collections;
private readonly SegmentedArray<Node> _array;

private FlatArrayIntervalTree(SegmentedArray<Node> array)
{
_array = array;
}
=> _array = array;

/// <summary>
/// Provides access to lots of common algorithms on this interval tree.
Expand Down Expand Up @@ -109,8 +97,8 @@ public static FlatArrayIntervalTree<T> CreateFromSorted<TIntrospector>(in TIntro

static void BuildCompleteTreeTop(SegmentedList<T> source, SegmentedArray<Node> destination)
{
// The nature of a complete tree is that the last level always only contains the odd remaining numbers.
// For example, given the initial values a-n:
// The nature of a complete tree is that the last level always only contains the elements at the even
// indices of the original source. For example, given the initial values a-n:
//
// a, b, c, d, e, f, g, h, i, j, k, l, m, n. The final tree will look like:
// h, d, l, b, f, j, n, a, c, e, g, i, k, m. Which corresponds to:
Expand All @@ -123,8 +111,9 @@ static void BuildCompleteTreeTop(SegmentedList<T> source, SegmentedArray<Node> d
// / \ / \ / \ /
// a c e g i k m
//
// Note that the first 3 levels are the even elements of the original list) which end up forming a perfect
// balanced tree, and the odd elements of the original list are the remaining values on the last level.
// Note that the first 3 levels are the elements at the odd indices of the original list) which end up
// forming a perfect balanced tree, and the elements at the even indices of the original list are the
// remaining values on the last level.

// How many levels will be in the perfect binary tree. For the example above, this would be 3.
var level = SegmentedArraySortUtils.Log2((uint)source.Count + 1);
Expand All @@ -146,49 +135,50 @@ static void BuildCompleteTreeTop(SegmentedList<T> source, SegmentedArray<Node> d

// The above loop will do the following over the first few iterations (changes highlighted with *):
//
// Dst: ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, , *m* // m placed at the end of the destination.
// Dst: ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, , *m* // m placed at the end of the destination.
// Src: a, b, c, d, e, f, g, h, i, j, k, l, *l*, n // l moved to where m was in the original source.
//
// Dst: ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, , *k*, m // k placed right before m in the destination.
// Dst: ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, , *k*, m // k placed right before m in the destination.
// Src: a, b, c, d, e, f, g, h, i, j, k, *j*, l, n // j moved right before where we placed l in the original source.
//
// Dst: ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, , *i*, k, m // i placed right before k in the destination.
// Dst: ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, , *i*, k, m // i placed right before k in the destination.
// Src: a, b, c, d, e, f, g, h, i, j, *h*, j, l, n // h moved right before where we placed j in the original source.
//
// Each iteration takes the next odd element from the end of the source list and places it at the
// next available space from the end of the destination array (effectively building the last row of
// the complete binary tree).
// Each iteration takes the next element at an even index from the end of the source list and places
// it at the next available space from the end of the destination array (effectively building the
// last row of the complete binary tree).
//
// It then takes the next even element from the end of the source list and moves it to the next spot
// from the end of the source list. This makes the end of the source-list contain the original even
// elements (up the perfect-complete count of elements), now abutted against each other.
// It then takes the next element at an odd index from the end of the source list and moves it to
// the next spot from the end of the source list. This makes the end of the source-list contain the
// original odd-indexed elements (up the perfect-complete count of elements), now abutted against
// each other.
}

// After this, source will be equal to:
// After the loop above fully completes, source will be equal to:
//
// a, b, c, d, e, f, g - b, d, f, h, j, l, n.
//
// In other words, the last half (after 'g') will be updated to be the even elements from the original
// list. This will be what we'll create the perfect tree from below. We will not look at the elements
// before this in 'source' as they are already either in the correct place in the 'source' *or*
// 'destination' arrays.
// The last half (after 'g') will be updated to be the odd-indexed elements from the original list.
// This will be what we'll create the perfect tree from below. We will not look at the elements before
// this in 'source' as they are already either in the correct place in the 'source' *or* 'destination'
// arrays.
//
// Destination will be equal to:
// ␀, ␀, ␀, ␀, ␀, ␀, ␀, ␀, c, e, g, i, k, m
// ∞, ∞, ∞, ∞, ∞, ∞, ∞, ∞, c, e, g, i, k, m
//
// which is the odd elements from the original list.
// which is the elements at the original even indices from the original list.

// The above loop will not hit the first element in the list (since we do not want to do a swap for the
// root element). So we have to handle this case specially at the end.
var firstOddIndex = destination.Length - extraElementsCount;
destination[firstOddIndex] = new Node(source[0], MaxEndNodeIndex: firstOddIndex);
// Destination will be equal to:
// ␀, ␀, ␀, ␀, ␀, ␀, ␀, a, c, e, g, i, k, m
// ∞, ∞, ∞, ∞, ∞, ∞, ∞, a, c, e, g, i, k, m
}

// Recursively build the perfect balanced subtree from the remaining elements, storing them into the start
// of the array. In the above example, this is building the perfect balanced tree for the event elements
// 8-14.
// of the array. In the above example, this is building the perfect balanced tree for the elements
// b, d, f, h, j, l, n.
BuildCompleteTreeRecursive(
source, destination, startInclusive: extraElementsCount, endExclusive: source.Count, destinationIndex: 0);
}
Expand Down Expand Up @@ -227,7 +217,7 @@ static int ComputeMaxEndNodes(SegmentedArray<Node> array, int currentNodeIndex,
// Now get the max end of the left and right children and compare to our end. Whichever is the rightmost
// endpoint is considered the max end index.
var currentNode = array[currentNodeIndex];
var thisEndValue = GetEnd(currentNode.Value, in introspector);
var thisEndValue = introspector.GetSpan(currentNode.Value).End;

if (thisEndValue >= leftMaxEndValue && thisEndValue >= rightMaxEndValue)
{
Expand Down Expand Up @@ -264,167 +254,51 @@ private static int GetLeftChildIndex(int nodeIndex)
private static int GetRightChildIndex(int nodeIndex)
=> (2 * nodeIndex) + 2;

private static int GetEnd<TIntrospector>(T value, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
=> introspector.GetSpan(value).End;

bool IIntervalTree<T>.Any<TIntrospector>(int start, int length, TestInterval<T, TIntrospector> testInterval, in TIntrospector introspector)
{
// Inlined version of FillWithIntervalsThatMatch, optimized to do less work and stop once it finds a match.
var array = _array;
if (array.Length == 0)
return false;

using var _ = s_nodeIndexPool.GetPooledObject(out var candidates);

var end = start + length;

candidates.Push(0);

while (candidates.TryPop(out var currentNodeIndex))
{
// Check the nodes as we go down. That way we can stop immediately when we find something that matches,
// instead of having to do an entire in-order walk, which might end up hitting a lot of nodes we don't care
// about and placing a lot into the stack.
var node = array[currentNodeIndex];
if (testInterval(node.Value, start, length, in introspector))
return true;

if (ShouldExamineRight(array, start, end, currentNodeIndex, in introspector, out var rightIndex))
candidates.Push(rightIndex);

if (ShouldExamineLeft(array, start, currentNodeIndex, in introspector, out var leftIndex))
candidates.Push(leftIndex);
}

return false;
}
=> IntervalTreeHelpers<T, FlatArrayIntervalTree<T>, /*TNode*/ int, FlatArrayIntervalTreeHelper>.Any(this, start, length, testInterval, in introspector);

int IIntervalTree<T>.FillWithIntervalsThatMatch<TIntrospector>(
int start, int length, TestInterval<T, TIntrospector> testInterval,
ref TemporaryArray<T> builder, in TIntrospector introspector,
bool stopAfterFirst)
{
var array = _array;
if (array.Length == 0)
return 0;

using var _ = s_stackPool.GetPooledObject(out var candidates);

var matches = 0;
var end = start + length;

candidates.Push((nodeIndex: 0, firstTime: true));

while (candidates.TryPop(out var currentTuple))
{
var currentNodeIndex = currentTuple.nodeIndex;
var currentNode = array[currentNodeIndex];

if (!currentTuple.firstTime)
{
// We're seeing this node for the second time (as we walk back up the left
// side of it). Now see if it matches our test, and if so return it out.
if (testInterval(currentNode.Value, start, length, in introspector))
{
matches++;
builder.Add(currentNode.Value);

if (stopAfterFirst)
return 1;
}
}
else
{
// First time we're seeing this node. In order to see the node 'in-order', we push the right side, then
// the node again, then the left side. This time we mark the current node with 'false' to indicate that
// it's the second time we're seeing it the next time it comes around.

if (ShouldExamineRight(array, start, end, currentNodeIndex, in introspector, out var right))
candidates.Push((right, firstTime: true));

candidates.Push((currentNodeIndex, firstTime: false));

if (ShouldExamineLeft(array, start, currentNodeIndex, in introspector, out var left))
candidates.Push((left, firstTime: true));
}
}

return matches;
}

private static bool ShouldExamineRight<TIntrospector>(
SegmentedArray<Node> array,
int start,
int end,
int currentNodeIndex,
in TIntrospector introspector,
out int rightIndex) where TIntrospector : struct, IIntervalIntrospector<T>
{
// right children's starts will never be to the left of the parent's start so we should consider right
// subtree only if root's start overlaps with interval's End,
if (introspector.GetSpan(array[currentNodeIndex].Value).Start <= end)
{
rightIndex = GetRightChildIndex(currentNodeIndex);
if (rightIndex < array.Length && GetEnd(array[array[rightIndex].MaxEndNodeIndex].Value, in introspector) >= start)
return true;
}

rightIndex = 0;
return false;
}

private static bool ShouldExamineLeft<TIntrospector>(
SegmentedArray<Node> array,
int start,
int currentNodeIndex,
in TIntrospector introspector,
out int leftIndex) where TIntrospector : struct, IIntervalIntrospector<T>
{
// only if left's maxVal overlaps with interval's start, we should consider
// left subtree
leftIndex = GetLeftChildIndex(currentNodeIndex);
if (leftIndex < array.Length && GetEnd(array[array[leftIndex].MaxEndNodeIndex].Value, in introspector) >= start)
return true;

return false;
return IntervalTreeHelpers<T, FlatArrayIntervalTree<T>, /*TNode*/ int, FlatArrayIntervalTreeHelper>.FillWithIntervalsThatMatch(
this, start, length, testInterval, ref builder, in introspector, stopAfterFirst);
}

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();

public IEnumerator<T> GetEnumerator()
=> IntervalTreeHelpers<T, FlatArrayIntervalTree<T>, /*TNode*/ int, FlatArrayIntervalTreeHelper>.GetEnumerator(this);

/// <summary>
/// Wrapper type to allow the IntervalTreeHelpers type to work with this type.
/// </summary>
private readonly struct FlatArrayIntervalTreeHelper : IIntervalTreeHelper<T, FlatArrayIntervalTree<T>, int>
{
var array = _array;
return array.Length == 0 ? SpecializedCollections.EmptyEnumerator<T>() : GetEnumeratorWorker(array);
public T GetValue(FlatArrayIntervalTree<T> tree, int node)
=> tree._array[node].Value;

static IEnumerator<T> GetEnumeratorWorker(SegmentedArray<Node> array)
{
using var _ = s_stackPool.GetPooledObject(out var candidates);
candidates.Push((0, firstTime: true));
while (candidates.TryPop(out var tuple))
{
var (currentNodeIndex, firstTime) = tuple;
if (firstTime)
{
// First time seeing this node. Mark that we've been seen and recurse down the left side. The
// next time we see this node we'll yield it out.
var rightIndex = GetRightChildIndex(currentNodeIndex);
var leftIndex = GetLeftChildIndex(currentNodeIndex);
public int GetMaxEndNode(FlatArrayIntervalTree<T> tree, int node)
=> tree._array[node].MaxEndNodeIndex;

if (rightIndex < array.Length)
candidates.Push((rightIndex, firstTime: true));
public bool TryGetRoot(FlatArrayIntervalTree<T> tree, out int root)
{
root = 0;
return tree._array.Length > 0;
}

candidates.Push((currentNodeIndex, firstTime: false));
public bool TryGetLeftNode(FlatArrayIntervalTree<T> tree, int node, out int leftNode)
{
leftNode = GetLeftChildIndex(node);
return leftNode < tree._array.Length;
}

if (leftIndex < array.Length)
candidates.Push((leftIndex, firstTime: true));
}
else
{
yield return array[currentNodeIndex].Value;
}
}
public bool TryGetRightNode(FlatArrayIntervalTree<T> tree, int node, out int rightNode)
{
rightNode = GetRightChildIndex(node);
return rightNode < tree._array.Length;
}
}
}
Loading
Loading