Skip to content

Commit

Permalink
Merge pull request #74358 from CyrusNajmabadi/farGrups
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi committed Jul 12, 2024
2 parents f796860 + 0b0cb32 commit b05da42
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,26 @@ namespace Microsoft.CodeAnalysis.FindSymbols;

using Reference = (SymbolGroup group, ISymbol symbol, ReferenceLocation location);

internal partial class FindReferencesSearchEngine
internal partial class FindReferencesSearchEngine(
Solution solution,
IImmutableSet<Document>? documents,
ImmutableArray<IReferenceFinder> finders,
IStreamingFindReferencesProgress progress,
FindReferencesSearchOptions options)
{
private readonly Solution _solution;
private readonly IImmutableSet<Document>? _documents;
private readonly ImmutableArray<IReferenceFinder> _finders;
private readonly IStreamingProgressTracker _progressTracker;
private readonly IStreamingFindReferencesProgress _progress;
private readonly FindReferencesSearchOptions _options;

private static readonly TaskScheduler s_exclusiveScheduler = new ConcurrentExclusiveSchedulerPair().ExclusiveScheduler;

/// <summary>
/// Mapping from symbols (unified across metadata/retargeting) and the set of symbols that was produced for
/// them in the case of linked files across projects. This allows references to be found to any of the unified
/// symbols, while the user only gets a single reported group back that corresponds to that entire set.
/// Scheduler we use when we're doing operations in the BG and we want to rate limit them to not saturate the threadpool.
/// </summary>
private readonly ConcurrentDictionary<ISymbol, SymbolGroup> _symbolToGroup = new(MetadataUnifyingEquivalenceComparer.Instance);

public FindReferencesSearchEngine(
Solution solution,
IImmutableSet<Document>? documents,
ImmutableArray<IReferenceFinder> finders,
IStreamingFindReferencesProgress progress,
FindReferencesSearchOptions options)
{
_documents = documents;
_solution = solution;
_finders = finders;
_progress = progress;
_options = options;
private static readonly TaskScheduler s_exclusiveScheduler = new ConcurrentExclusiveSchedulerPair().ExclusiveScheduler;

_progressTracker = progress.ProgressTracker;
}
private static readonly ObjectPool<Dictionary<ISymbol, SymbolGroup>> s_symbolToGroupPool = new(() => new(MetadataUnifyingEquivalenceComparer.Instance));

private readonly Solution _solution = solution;
private readonly IImmutableSet<Document>? _documents = documents;
private readonly ImmutableArray<IReferenceFinder> _finders = finders;
private readonly IStreamingProgressTracker _progressTracker = progress.ProgressTracker;
private readonly IStreamingFindReferencesProgress _progress = progress;
private readonly FindReferencesSearchOptions _options = options;

/// <summary>
/// Options to control the parallelism of the search. If we're in <see
Expand Down Expand Up @@ -95,11 +82,19 @@ await ProducerConsumer<Reference>.RunAsync(
private async Task PerformSearchAsync(
ImmutableArray<ISymbol> symbols, Action<Reference> onReferenceFound, CancellationToken cancellationToken)
{
// Mapping from symbols (unified across metadata/retargeting) to the set of symbols that was produced for
// them in the case of linked files across projects. This allows references to be found to any of the unified
// symbols, while the user only gets a single reported group back that corresponds to that entire set.
//
// This is a normal dictionary that is not locked. It is only ever read and written to serially from within the
// high level project-walking code in this method.
using var _1 = s_symbolToGroupPool.GetPooledObject(out var symbolToGroup);

var unifiedSymbols = new MetadataUnifyingSymbolHashSet();
unifiedSymbols.AddRange(symbols);

var disposable = await _progressTracker.AddSingleItemAsync(cancellationToken).ConfigureAwait(false);
await using var _ = disposable.ConfigureAwait(false);
await using var _2 = disposable.ConfigureAwait(false);

// Create the initial set of symbols to search for. As we walk the appropriate projects in the solution
// we'll expand this set as we discover new symbols to search for in each project.
Expand All @@ -108,7 +103,9 @@ private async Task PerformSearchAsync(

// Report the initial set of symbols to the caller.
var allSymbols = symbolSet.GetAllSymbols();
await ReportGroupsAsync(allSymbols, cancellationToken).ConfigureAwait(false);

// Safe to call as we're in the entry-point method, and nothing is running concurrently with this call.
await ReportGroupsSeriallyAsync(allSymbols, symbolToGroup, cancellationToken).ConfigureAwait(false);

// Determine the set of projects we actually have to walk to find results in. If the caller provided a
// set of documents to search, we only bother with those.
Expand All @@ -118,15 +115,18 @@ private async Task PerformSearchAsync(

// Pull off and start searching each project as soon as we can once we've done the inheritance cascade into it.
await RoslynParallel.ForEachAsync(
GetProjectsAndSymbolsToSearchAsync(symbolSet, projectsToSearch, cancellationToken),
// ForEachAsync will serially pull on the IAsyncEnumerable returned here, kicking off the processing to then
// happen in parallel.
GetProjectsAndSymbolsToSearchSeriallyAsync(symbolSet, projectsToSearch, symbolToGroup, cancellationToken),
GetParallelOptions(cancellationToken),
async (tuple, cancellationToken) => await ProcessProjectAsync(
tuple.project, tuple.allSymbols, onReferenceFound, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false);
}

private async IAsyncEnumerable<(Project project, ImmutableArray<ISymbol> allSymbols)> GetProjectsAndSymbolsToSearchAsync(
private async IAsyncEnumerable<(Project project, ImmutableArray<(ISymbol symbol, SymbolGroup group)> allSymbols)> GetProjectsAndSymbolsToSearchSeriallyAsync(
SymbolSet symbolSet,
ImmutableArray<Project> projectsToSearch,
Dictionary<ISymbol, SymbolGroup> symbolToGroup,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
// We need to process projects in order when updating our symbol set. Say we have three projects (A, B
Expand All @@ -146,12 +146,13 @@ await RoslynParallel.ForEachAsync(
// which is why we do it in this loop and not inside the concurrent project processing that happens
// below.
await symbolSet.InheritanceCascadeAsync(currentProject, cancellationToken).ConfigureAwait(false);
var allSymbols = symbolSet.GetAllSymbols();

// Report any new symbols we've cascaded to to our caller.
await ReportGroupsAsync(allSymbols, cancellationToken).ConfigureAwait(false);
// Report to our caller any new symbols we've cascaded to. This is safe to call here as we're abiding by
// the serial requirements of ReportGroupsSeriallyAsync.
var allSymbolsAndGroups = await ReportGroupsSeriallyAsync(
symbolSet.GetAllSymbols(), symbolToGroup, cancellationToken).ConfigureAwait(false);

yield return (currentProject, allSymbols);
yield return (currentProject, allSymbolsAndGroups);
}
}

Expand All @@ -160,21 +161,28 @@ await RoslynParallel.ForEachAsync(
/// them once per symbol group, but we may have to notify about new symbols each time we expand our symbol set
/// when we walk into a new project.
/// </summary>
private async Task ReportGroupsAsync(ImmutableArray<ISymbol> symbols, CancellationToken cancellationToken)
private async Task<ImmutableArray<(ISymbol symbol, SymbolGroup group)>> ReportGroupsSeriallyAsync(
ImmutableArray<ISymbol> symbols, Dictionary<ISymbol, SymbolGroup> symbolToGroup, CancellationToken cancellationToken)
{
var result = new FixedSizeArrayBuilder<(ISymbol symbol, SymbolGroup group)>(symbols.Length);

// Safe to call this as we're only being called from within a serial context ourselves.
foreach (var symbol in symbols)
await ReportGroupAsync(symbol, cancellationToken).ConfigureAwait(false);
result.Add((symbol, await ReportGroupSeriallyAsync(symbol, symbolToGroup, cancellationToken).ConfigureAwait(false)));

return result.MoveToImmutable();
}

private async ValueTask<SymbolGroup> ReportGroupAsync(ISymbol symbol, CancellationToken cancellationToken)
private async ValueTask<SymbolGroup> ReportGroupSeriallyAsync(
ISymbol symbol, Dictionary<ISymbol, SymbolGroup> symbolToGroup, CancellationToken cancellationToken)
{
// See if this is the first time we're running across this symbol. Note: no locks are needed
// here between checking and then adding because this is only ever called serially from within
// FindReferencesAsync above. A plain Dictionary suffices: the group found or created here is
// returned to the caller and passed along with its symbol, so the later parallel work in
// ProcessDocumentAsync receives the group directly and never reads this map.
if (!_symbolToGroup.TryGetValue(symbol, out var group))
if (!symbolToGroup.TryGetValue(symbol, out var group))
{
var linkedSymbols = await SymbolFinder.FindLinkedSymbolsAsync(symbol, _solution, cancellationToken).ConfigureAwait(false);
Contract.ThrowIfFalse(linkedSymbols.Contains(symbol), "Linked symbols did not contain the very symbol we started with.");
Expand All @@ -183,11 +191,11 @@ private async ValueTask<SymbolGroup> ReportGroupAsync(ISymbol symbol, Cancellati
Contract.ThrowIfFalse(group.Symbols.Contains(symbol), "Symbol group did not contain the very symbol we started with.");

foreach (var groupSymbol in group.Symbols)
_symbolToGroup.TryAdd(groupSymbol, group);
symbolToGroup.TryAdd(groupSymbol, group);

// Since "symbol" was in group.Symbols, and we just added links from all of group.Symbols to that group, then "symbol"
// better now be in symbolToGroup.
Contract.ThrowIfFalse(_symbolToGroup.ContainsKey(symbol));
Contract.ThrowIfFalse(symbolToGroup.ContainsKey(symbol));

await _progress.OnDefinitionFoundAsync(group, cancellationToken).ConfigureAwait(false);
}
Expand All @@ -206,18 +214,18 @@ private Task<ImmutableArray<Project>> GetProjectsToSearchAsync(
}

private async ValueTask ProcessProjectAsync(
Project project, ImmutableArray<ISymbol> allSymbols, Action<Reference> onReferenceFound, CancellationToken cancellationToken)
Project project, ImmutableArray<(ISymbol symbol, SymbolGroup group)> allSymbols, Action<Reference> onReferenceFound, CancellationToken cancellationToken)
{
using var _1 = PooledDictionary<ISymbol, PooledHashSet<string>>.GetInstance(out var symbolToGlobalAliases);
using var _2 = PooledDictionary<Document, MetadataUnifyingSymbolHashSet>.GetInstance(out var documentToSymbols);
using var _2 = PooledDictionary<Document, Dictionary<ISymbol, SymbolGroup>>.GetInstance(out var documentToSymbolsWithin);
try
{
// scratch hashset to place results in. Populated/inspected/cleared in inner loop.
using var _3 = PooledHashSet<Document>.GetInstance(out var foundDocuments);

await AddGlobalAliasesAsync(project, allSymbols, symbolToGlobalAliases, cancellationToken).ConfigureAwait(false);

foreach (var symbol in allSymbols)
foreach (var (symbol, group) in allSymbols)
{
var globalAliases = TryGet(symbolToGlobalAliases, symbol);

Expand All @@ -230,38 +238,41 @@ await finder.DetermineDocumentsToSearchAsync(
_options, cancellationToken).ConfigureAwait(false);

foreach (var document in foundDocuments)
GetSymbolSet(documentToSymbols, document).Add(symbol);
{
var symbolsWithin = documentToSymbolsWithin.GetOrAdd(document, static _ => s_symbolToGroupPool.AllocateAndClear());
symbolsWithin[symbol] = group;
}

foundDocuments.Clear();
}
}

await RoslynParallel.ForEachAsync(
documentToSymbols,
documentToSymbolsWithin,
GetParallelOptions(cancellationToken),
(kvp, cancellationToken) =>
ProcessDocumentAsync(kvp.Key, kvp.Value, symbolToGlobalAliases, onReferenceFound, cancellationToken)).ConfigureAwait(false);
}
finally
{
foreach (var (_, symbols) in documentToSymbols)
MetadataUnifyingSymbolHashSet.ClearAndFree(symbols);
foreach (var (_, symbolsWithin) in documentToSymbolsWithin)
{
symbolsWithin.Clear();
s_symbolToGroupPool.Free(symbolsWithin);
}

FreeGlobalAliases(symbolToGlobalAliases);

await _progressTracker.ItemCompletedAsync(cancellationToken).ConfigureAwait(false);
}

static MetadataUnifyingSymbolHashSet GetSymbolSet<T>(PooledDictionary<T, MetadataUnifyingSymbolHashSet> dictionary, T key) where T : notnull
=> dictionary.GetOrAdd(key, static _ => MetadataUnifyingSymbolHashSet.AllocateFromPool());
}

private static PooledHashSet<U>? TryGet<T, U>(Dictionary<T, PooledHashSet<U>> dictionary, T key) where T : notnull
=> dictionary.TryGetValue(key, out var set) ? set : null;

private async ValueTask ProcessDocumentAsync(
Document document,
MetadataUnifyingSymbolHashSet symbols,
Dictionary<ISymbol, SymbolGroup> symbolsToSearchFor,
Dictionary<ISymbol, PooledHashSet<string>> symbolToGlobalAliases,
Action<Reference> onReferenceFound,
CancellationToken cancellationToken)
Expand All @@ -281,48 +292,46 @@ private async ValueTask ProcessDocumentAsync(
// Note: cascaded symbols will normally have the same name. That's ok. The second call to
// FindMatchingIdentifierTokens with the same name will short circuit since it will already see the result of
// the prior call.
foreach (var symbol in symbols)
foreach (var (symbol, _) in symbolsToSearchFor)
{
if (symbol.CanBeReferencedByName)
cache.FindMatchingIdentifierTokens(symbol.Name, cancellationToken);
}

await RoslynParallel.ForEachAsync(
symbols,
symbolsToSearchFor,
GetParallelOptions(cancellationToken),
(symbol, cancellationToken) =>
(kvp, cancellationToken) =>
{
var (symbolToSearchFor, symbolGroup) = kvp;
// symbolToGlobalAliases is safe to read in parallel. It is created fully before this point and is no
// longer mutated.
var state = new FindReferencesDocumentState(
cache, TryGet(symbolToGlobalAliases, symbol));
cache, TryGet(symbolToGlobalAliases, symbolToSearchFor));
ProcessDocument(symbol, state, onReferenceFound);
ProcessDocument(symbolToSearchFor, symbolGroup, state, onReferenceFound);
return ValueTaskFactory.CompletedTask;
}).ConfigureAwait(false);

return;

void ProcessDocument(
ISymbol symbol, FindReferencesDocumentState state, Action<Reference> onReferenceFound)
ISymbol symbolToSearchFor, SymbolGroup symbolGroup, FindReferencesDocumentState state, Action<Reference> onReferenceFound)
{
cancellationToken.ThrowIfCancellationRequested();

using (Logger.LogBlock(FunctionId.FindReference_ProcessDocumentAsync, cancellationToken))
{
// This is safe to just blindly read. We can only ever get here after the call to ReportGroupsAsync
// happened. So there must be a group for this symbol in our map.
var group = _symbolToGroup[symbol];

// Note: nearly every finder will no-op when passed a symbol it's not applicable to. So it's
// simple to just iterate over all of them, knowing that will quickly skip all the irrelevant ones,
// and only do interesting work on the single relevant one.
foreach (var finder in _finders)
{
finder.FindReferencesInDocument(
symbol, state,
static (loc, tuple) => tuple.onReferenceFound((tuple.group, tuple.symbol, loc.Location)),
(group, symbol, onReferenceFound),
symbolToSearchFor, state,
static (loc, tuple) => tuple.onReferenceFound((tuple.symbolGroup, tuple.symbolToSearchFor, loc.Location)),
(symbolGroup, symbolToSearchFor, onReferenceFound),
_options,
cancellationToken);
}
Expand All @@ -332,11 +341,11 @@ void ProcessDocument(

private async Task AddGlobalAliasesAsync(
Project project,
ImmutableArray<ISymbol> allSymbols,
ImmutableArray<(ISymbol symbol, SymbolGroup group)> allSymbols,
PooledDictionary<ISymbol, PooledHashSet<string>> symbolToGlobalAliases,
CancellationToken cancellationToken)
{
foreach (var symbol in allSymbols)
foreach (var (symbol, _) in allSymbols)
{
foreach (var finder in _finders)
{
Expand Down
Loading

0 comments on commit b05da42

Please sign in to comment.