Simplify passing around of FAR symbol groups #74358

Merged
merged 10 commits into from
Jul 12, 2024
@@ -21,39 +21,26 @@ namespace Microsoft.CodeAnalysis.FindSymbols;

using Reference = (SymbolGroup group, ISymbol symbol, ReferenceLocation location);

internal partial class FindReferencesSearchEngine
internal partial class FindReferencesSearchEngine(
Solution solution,
IImmutableSet<Document>? documents,
ImmutableArray<IReferenceFinder> finders,
IStreamingFindReferencesProgress progress,
FindReferencesSearchOptions options)
{
private readonly Solution _solution;
private readonly IImmutableSet<Document>? _documents;
private readonly ImmutableArray<IReferenceFinder> _finders;
private readonly IStreamingProgressTracker _progressTracker;
private readonly IStreamingFindReferencesProgress _progress;
private readonly FindReferencesSearchOptions _options;

private static readonly TaskScheduler s_exclusiveScheduler = new ConcurrentExclusiveSchedulerPair().ExclusiveScheduler;

/// <summary>
/// Mapping from symbols (unified across metadata/retargeting) and the set of symbols that was produced for
/// them in the case of linked files across projects. This allows references to be found to any of the unified
/// symbols, while the user only gets a single reported group back that corresponds to that entire set.
/// Scheduler we use when we're doing operations in the BG and we want to rate limit them to not saturate the threadpool.
/// </summary>
private readonly ConcurrentDictionary<ISymbol, SymbolGroup> _symbolToGroup = new(MetadataUnifyingEquivalenceComparer.Instance);
Member Author

Removed this from being ambient state; it is now just a normal scratch buffer created within the FAR algorithm (like all the other maps/sets/lists it has).
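
To illustrate the "scratch buffer instead of ambient state" idea in isolation, here is a minimal sketch. `ScratchPool<T>` and `DistinctCounter` are hypothetical names invented for this example; the pool is a simplified stand-in and not Roslyn's `ObjectPool<T>`. The map is rented at the top of the method, used only by that call, then cleared and returned.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;

// Hypothetical minimal pool; a simplified stand-in, not Roslyn's ObjectPool<T>.
public sealed class ScratchPool<T> where T : class
{
    private readonly ConcurrentBag<T> _items = new();
    private readonly Func<T> _factory;

    public ScratchPool(Func<T> factory) => _factory = factory;

    // Reuse a previously returned instance if one is available, otherwise create one.
    public T Rent() => _items.TryTake(out var item) ? item : _factory();

    public void Return(T item) => _items.Add(item);
}

public static class DistinctCounter
{
    private static readonly ScratchPool<Dictionary<string, int>> s_pool =
        new(() => new Dictionary<string, int>(StringComparer.Ordinal));

    // The map is local scratch state: rented here, used serially, cleared, and returned.
    // No other method can observe it, so no locking or concurrent collection is needed.
    public static int CountDistinct(IEnumerable<string> names)
    {
        var firstIndex = s_pool.Rent();
        try
        {
            var index = 0;
            foreach (var name in names)
            {
                firstIndex.TryAdd(name, index);
                index++;
            }

            return firstIndex.Count;
        }
        finally
        {
            firstIndex.Clear();
            s_pool.Return(firstIndex);
        }
    }
}
```

Because the dictionary never escapes the call, a plain `Dictionary` is sufficient in this sketch; the pool only avoids reallocating it on each call.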


public FindReferencesSearchEngine(
Solution solution,
IImmutableSet<Document>? documents,
ImmutableArray<IReferenceFinder> finders,
IStreamingFindReferencesProgress progress,
FindReferencesSearchOptions options)
{
_documents = documents;
_solution = solution;
_finders = finders;
_progress = progress;
_options = options;
private static readonly TaskScheduler s_exclusiveScheduler = new ConcurrentExclusiveSchedulerPair().ExclusiveScheduler;

_progressTracker = progress.ProgressTracker;
}
private static readonly ObjectPool<Dictionary<ISymbol, SymbolGroup>> s_symbolToGroupPool = new(() => new(MetadataUnifyingEquivalenceComparer.Instance));

private readonly Solution _solution = solution;
Member

why do these need to be reassigned?
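
One possible answer, offered as general C# 12 behavior rather than the author's stated reason: a primary-constructor parameter that is used inside a member is captured as hidden, mutable state, whereas a parameter used only in a field initializer is not captured, and the `readonly` field cannot be reassigned later. The types below (`Captured`, `Assigned`) are hypothetical and exist only to show the difference.

```csharp
using System;

// Hypothetical type: the parameter is used in members, so it is captured
// as hidden mutable state that any member may overwrite.
public sealed class Captured(int value)
{
    public int Read() => value;

    public void Overwrite(int next) => value = next;
}

// Hypothetical type: the parameter is used only to initialize a readonly field,
// so it is not captured, and the field cannot be reassigned after construction.
public sealed class Assigned(int value)
{
    private readonly int _value = value;

    public int Read() => _value;

    // public void Overwrite(int next) => _value = next; // would not compile
}
```

Initializing a derived field such as `_progressTracker = progress.ProgressTracker` in the diff also computes the value once at construction instead of on each access.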

private readonly IImmutableSet<Document>? _documents = documents;
private readonly ImmutableArray<IReferenceFinder> _finders = finders;
private readonly IStreamingProgressTracker _progressTracker = progress.ProgressTracker;
private readonly IStreamingFindReferencesProgress _progress = progress;
private readonly FindReferencesSearchOptions _options = options;

/// <summary>
/// Options to control the parallelism of the search. If we're in <see
@@ -95,11 +82,19 @@ await ProducerConsumer<Reference>.RunAsync(
private async Task PerformSearchAsync(
ImmutableArray<ISymbol> symbols, Action<Reference> onReferenceFound, CancellationToken cancellationToken)
{
// Mapping from symbols (unified across metadata/retargeting) and the set of symbols that was produced for
// them in the case of linked files across projects. This allows references to be found to any of the unified
// symbols, while the user only gets a single reported group back that corresponds to that entire set.
//
// This is a normal dictionary that is not locked. It is only ever read and written to serially from within the
// high level project-walking code in this method.
using var _1 = s_symbolToGroupPool.GetPooledObject(out var symbolToGroup);

var unifiedSymbols = new MetadataUnifyingSymbolHashSet();
unifiedSymbols.AddRange(symbols);

var disposable = await _progressTracker.AddSingleItemAsync(cancellationToken).ConfigureAwait(false);
await using var _ = disposable.ConfigureAwait(false);
await using var _2 = disposable.ConfigureAwait(false);

// Create the initial set of symbols to search for. As we walk the appropriate projects in the solution
// we'll expand this set as we discover new symbols to search for in each project.
@@ -108,7 +103,9 @@ private async Task PerformSearchAsync(

// Report the initial set of symbols to the caller.
var allSymbols = symbolSet.GetAllSymbols();
await ReportGroupsAsync(allSymbols, cancellationToken).ConfigureAwait(false);

// Safe to call as we're in the entry-point method, and nothing is running concurrently with this call.
await ReportGroupsSeriallyAsync(allSymbols, symbolToGroup, cancellationToken).ConfigureAwait(false);

// Determine the set of projects we actually have to walk to find results in. If the caller provided a
// set of documents to search, we only bother with those.
@@ -118,15 +115,18 @@ private async Task PerformSearchAsync(

// Pull off and start searching each project as soon as we can once we've done the inheritance cascade into it.
await RoslynParallel.ForEachAsync(
GetProjectsAndSymbolsToSearchAsync(symbolSet, projectsToSearch, cancellationToken),
// ForEachAsync will serially pull on the IAsyncEnumerable returned here, kicking off the processing to then
// happen in parallel.
GetProjectsAndSymbolsToSearchSeriallyAsync(symbolSet, projectsToSearch, symbolToGroup, cancellationToken),
GetParallelOptions(cancellationToken),
async (tuple, cancellationToken) => await ProcessProjectAsync(
tuple.project, tuple.allSymbols, onReferenceFound, cancellationToken).ConfigureAwait(false)).ConfigureAwait(false);
}

private async IAsyncEnumerable<(Project project, ImmutableArray<ISymbol> allSymbols)> GetProjectsAndSymbolsToSearchAsync(
private async IAsyncEnumerable<(Project project, ImmutableArray<(ISymbol symbol, SymbolGroup group)> allSymbols)> GetProjectsAndSymbolsToSearchSeriallyAsync(
SymbolSet symbolSet,
ImmutableArray<Project> projectsToSearch,
Dictionary<ISymbol, SymbolGroup> symbolToGroup,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
// We need to process projects in order when updating our symbol set. Say we have three projects (A, B
@@ -146,12 +146,13 @@ await RoslynParallel.ForEachAsync(
// which is why we do it in this loop and not inside the concurrent project processing that happens
// below.
await symbolSet.InheritanceCascadeAsync(currentProject, cancellationToken).ConfigureAwait(false);
var allSymbols = symbolSet.GetAllSymbols();

// Report any new symbols we've cascaded to to our caller.
await ReportGroupsAsync(allSymbols, cancellationToken).ConfigureAwait(false);
// Report any new symbols we've cascaded to to our caller. This is safe to call here as we're abiding by
// the serial requirements of ReportGroupsSeriallyAsync
var allSymbolsAndGroups = await ReportGroupsSeriallyAsync(
symbolSet.GetAllSymbols(), symbolToGroup, cancellationToken).ConfigureAwait(false);

yield return (currentProject, allSymbols);
yield return (currentProject, allSymbolsAndGroups);
}
}

@@ -160,21 +161,28 @@ await RoslynParallel.ForEachAsync(
/// them once per symbol group, but we may have to notify about new symbols each time we expand our symbol set
/// when we walk into a new project.
/// </summary>
private async Task ReportGroupsAsync(ImmutableArray<ISymbol> symbols, CancellationToken cancellationToken)
private async Task<ImmutableArray<(ISymbol symbol, SymbolGroup group)>> ReportGroupsSeriallyAsync(
ImmutableArray<ISymbol> symbols, Dictionary<ISymbol, SymbolGroup> symbolToGroup, CancellationToken cancellationToken)
{
var result = new FixedSizeArrayBuilder<(ISymbol symbol, SymbolGroup group)>(symbols.Length);

// Safe to call this as we're only being called from within a serial context ourselves.
foreach (var symbol in symbols)
await ReportGroupAsync(symbol, cancellationToken).ConfigureAwait(false);
result.Add((symbol, await ReportGroupSeriallyAsync(symbol, symbolToGroup, cancellationToken).ConfigureAwait(false)));
Member Author

here we ensure that the symbol and its group always travel together.
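
A minimal sketch of that "travel together" shape, using strings in place of `ISymbol` and a hypothetical `Group` record in place of `SymbolGroup` (both are assumptions made to keep the example self-contained). The method must be called serially, since it reads and writes a plain `Dictionary`, and it returns each symbol already paired with its group.

```csharp
using System.Collections.Generic;
using System.Collections.Immutable;

// Hypothetical stand-in for SymbolGroup.
public sealed record Group(int Id);

public static class Pairing
{
    // Serial-only: looks up or creates the group for each symbol, then returns
    // the symbol and its group as one unit so later stages never need the map.
    public static ImmutableArray<(string symbol, Group group)> PairWithGroups(
        ImmutableArray<string> symbols, Dictionary<string, Group> symbolToGroup)
    {
        var builder = ImmutableArray.CreateBuilder<(string symbol, Group group)>(symbols.Length);

        foreach (var symbol in symbols)
        {
            if (!symbolToGroup.TryGetValue(symbol, out var group))
            {
                group = new Group(symbolToGroup.Count);
                symbolToGroup.Add(symbol, group);
            }

            builder.Add((symbol, group));
        }

        // Exactly symbols.Length items were added, so the builder is at capacity.
        return builder.MoveToImmutable();
    }
}
```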


return result.MoveToImmutable();
}

private async ValueTask<SymbolGroup> ReportGroupAsync(ISymbol symbol, CancellationToken cancellationToken)
private async ValueTask<SymbolGroup> ReportGroupSeriallyAsync(
ISymbol symbol, Dictionary<ISymbol, SymbolGroup> symbolToGroup, CancellationToken cancellationToken)
{
// See if this is the first time we're running across this symbol. Note: no locks are needed
// here between checking and then adding because this is only ever called serially from within
// FindReferencesAsync above (though we still need a ConcurrentDictionary as reads of these
// symbols will happen later in ProcessDocumentAsync. However, those reads will only happen
// after the dependent symbol values were written in, so it will be safe to blindly read them
// out.
if (!_symbolToGroup.TryGetValue(symbol, out var group))
if (!symbolToGroup.TryGetValue(symbol, out var group))
{
var linkedSymbols = await SymbolFinder.FindLinkedSymbolsAsync(symbol, _solution, cancellationToken).ConfigureAwait(false);
Contract.ThrowIfFalse(linkedSymbols.Contains(symbol), "Linked symbols did not contain the very symbol we started with.");
@@ -183,11 +191,11 @@ private async ValueTask<SymbolGroup> ReportGroupAsync(ISymbol symbol, Cancellati
Contract.ThrowIfFalse(group.Symbols.Contains(symbol), "Symbol group did not contain the very symbol we started with.");

foreach (var groupSymbol in group.Symbols)
_symbolToGroup.TryAdd(groupSymbol, group);
symbolToGroup.TryAdd(groupSymbol, group);

// Since "symbol" was in group.Symbols, and we just added links from all of group.Symbols to that group, then "symbol"
// better now be in _symbolToGroup.
Contract.ThrowIfFalse(_symbolToGroup.ContainsKey(symbol));
Contract.ThrowIfFalse(symbolToGroup.ContainsKey(symbol));

await _progress.OnDefinitionFoundAsync(group, cancellationToken).ConfigureAwait(false);
}
@@ -206,18 +214,18 @@ private Task<ImmutableArray<Project>> GetProjectsToSearchAsync(
}

private async ValueTask ProcessProjectAsync(
Project project, ImmutableArray<ISymbol> allSymbols, Action<Reference> onReferenceFound, CancellationToken cancellationToken)
Project project, ImmutableArray<(ISymbol symbol, SymbolGroup group)> allSymbols, Action<Reference> onReferenceFound, CancellationToken cancellationToken)
{
using var _1 = PooledDictionary<ISymbol, PooledHashSet<string>>.GetInstance(out var symbolToGlobalAliases);
using var _2 = PooledDictionary<Document, MetadataUnifyingSymbolHashSet>.GetInstance(out var documentToSymbols);
using var _2 = PooledDictionary<Document, Dictionary<ISymbol, SymbolGroup>>.GetInstance(out var documentToSymbolsWithin);
try
{
// scratch hashset to place results in. Populated/inspected/cleared in inner loop.
using var _3 = PooledHashSet<Document>.GetInstance(out var foundDocuments);

await AddGlobalAliasesAsync(project, allSymbols, symbolToGlobalAliases, cancellationToken).ConfigureAwait(false);

foreach (var symbol in allSymbols)
foreach (var (symbol, group) in allSymbols)
{
var globalAliases = TryGet(symbolToGlobalAliases, symbol);

@@ -230,38 +238,41 @@ await finder.DetermineDocumentsToSearchAsync(
_options, cancellationToken).ConfigureAwait(false);

foreach (var document in foundDocuments)
GetSymbolSet(documentToSymbols, document).Add(symbol);
{
var symbolsWithin = documentToSymbolsWithin.GetOrAdd(document, static _ => s_symbolToGroupPool.AllocateAndClear());
symbolsWithin[symbol] = group;
}

foundDocuments.Clear();
}
}

await RoslynParallel.ForEachAsync(
documentToSymbols,
documentToSymbolsWithin,
GetParallelOptions(cancellationToken),
(kvp, cancellationToken) =>
ProcessDocumentAsync(kvp.Key, kvp.Value, symbolToGlobalAliases, onReferenceFound, cancellationToken)).ConfigureAwait(false);
}
finally
{
foreach (var (_, symbols) in documentToSymbols)
MetadataUnifyingSymbolHashSet.ClearAndFree(symbols);
foreach (var (_, symbolsWithin) in documentToSymbolsWithin)
{
symbolsWithin.Clear();
s_symbolToGroupPool.Free(symbolsWithin);
}

FreeGlobalAliases(symbolToGlobalAliases);

await _progressTracker.ItemCompletedAsync(cancellationToken).ConfigureAwait(false);
}

static MetadataUnifyingSymbolHashSet GetSymbolSet<T>(PooledDictionary<T, MetadataUnifyingSymbolHashSet> dictionary, T key) where T : notnull
=> dictionary.GetOrAdd(key, static _ => MetadataUnifyingSymbolHashSet.AllocateFromPool());
}

private static PooledHashSet<U>? TryGet<T, U>(Dictionary<T, PooledHashSet<U>> dictionary, T key) where T : notnull
=> dictionary.TryGetValue(key, out var set) ? set : null;

private async ValueTask ProcessDocumentAsync(
Document document,
MetadataUnifyingSymbolHashSet symbols,
Dictionary<ISymbol, SymbolGroup> symbolsToSearchFor,
Dictionary<ISymbol, PooledHashSet<string>> symbolToGlobalAliases,
Action<Reference> onReferenceFound,
CancellationToken cancellationToken)
@@ -281,48 +292,46 @@ private async ValueTask ProcessDocumentAsync(
// Note: cascaded symbols will normally have the same name. That's ok. The second call to
// FindMatchingIdentifierTokens with the same name will short circuit since it will already see the result of
// the prior call.
foreach (var symbol in symbols)
foreach (var (symbol, _) in symbolsToSearchFor)
{
if (symbol.CanBeReferencedByName)
cache.FindMatchingIdentifierTokens(symbol.Name, cancellationToken);
}

await RoslynParallel.ForEachAsync(
symbols,
symbolsToSearchFor,
GetParallelOptions(cancellationToken),
(symbol, cancellationToken) =>
(kvp, cancellationToken) =>
{
var (symbolToSearchFor, symbolGroup) = kvp;

// symbolToGlobalAliases is safe to read in parallel. It is created fully before this point and is no
// longer mutated.
var state = new FindReferencesDocumentState(
cache, TryGet(symbolToGlobalAliases, symbol));
cache, TryGet(symbolToGlobalAliases, symbolToSearchFor));

ProcessDocument(symbol, state, onReferenceFound);
ProcessDocument(symbolToSearchFor, symbolGroup, state, onReferenceFound);
return ValueTaskFactory.CompletedTask;
}).ConfigureAwait(false);

return;

void ProcessDocument(
ISymbol symbol, FindReferencesDocumentState state, Action<Reference> onReferenceFound)
ISymbol symbolToSearchFor, SymbolGroup symbolGroup, FindReferencesDocumentState state, Action<Reference> onReferenceFound)
{
cancellationToken.ThrowIfCancellationRequested();

using (Logger.LogBlock(FunctionId.FindReference_ProcessDocumentAsync, cancellationToken))
{
// This is safe to just blindly read. We can only ever get here after the call to ReportGroupsAsync
// happened. So there must be a group for this symbol in our map.
var group = _symbolToGroup[symbol];
Member Author

The crux of the change is that instead of having a collection be blindly written to and read from disparate locations, we just take the data at the point where we initially write it and pass that data along all the way to here, so we're never out of sync.
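
On the consuming side, the same idea can be sketched as follows. This is an illustration under assumptions, not the PR's code: `Consumer` and `ProcessAll` are hypothetical names, strings and ints stand in for symbols and groups, and `Parallel.ForEach` stands in for Roslyn's internal `RoslynParallel.ForEachAsync`. The parallel workers read the group from the work item itself, so there is no shared map to index into.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public static class Consumer
{
    // Each work item already carries its group, so workers never consult shared
    // mutable state. 'onFound' must be safe to call from multiple threads.
    public static void ProcessAll(
        IReadOnlyList<(string symbol, int group)> work,
        Action<(int group, string symbol)> onFound)
    {
        Parallel.ForEach(work, item =>
        {
            var (symbol, group) = item;
            onFound((group, symbol));
        });
    }
}
```

A caller could collect results into a `ConcurrentBag<(int, string)>` via `onFound`; the order of results is not deterministic because items are processed in parallel.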


// Note: nearly every finder will no-op when passed a in a symbol it's not applicable to. So it's
// simple to just iterate over all of them, knowing that will quickly skip all the irrelevant ones,
// and only do interesting work on the single relevant one.
foreach (var finder in _finders)
{
finder.FindReferencesInDocument(
symbol, state,
static (loc, tuple) => tuple.onReferenceFound((tuple.group, tuple.symbol, loc.Location)),
(group, symbol, onReferenceFound),
symbolToSearchFor, state,
static (loc, tuple) => tuple.onReferenceFound((tuple.symbolGroup, tuple.symbolToSearchFor, loc.Location)),
(symbolGroup, symbolToSearchFor, onReferenceFound),
_options,
cancellationToken);
}
@@ -332,11 +341,11 @@ void ProcessDocument(

private async Task AddGlobalAliasesAsync(
Project project,
ImmutableArray<ISymbol> allSymbols,
ImmutableArray<(ISymbol symbol, SymbolGroup group)> allSymbols,
PooledDictionary<ISymbol, PooledHashSet<string>> symbolToGlobalAliases,
CancellationToken cancellationToken)
{
foreach (var symbol in allSymbols)
foreach (var (symbol, _) in allSymbols)
{
foreach (var finder in _finders)
{