Skip to content

Commit

Permalink
Add more cases supported by 'use collection expression' (#75879)
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi authored Nov 14, 2024
2 parents 6acf726 + 18a8664 commit 747210f
Show file tree
Hide file tree
Showing 19 changed files with 570 additions and 93 deletions.
1 change: 1 addition & 0 deletions src/Analyzers/CSharp/Analyzers/CSharpAnalyzers.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
<Compile Include="$(MSBuildThisFileDirectory)UseCoalesceExpression\UseCoalesceExpressionHelpers.cs" />
<Compile Include="$(MSBuildThisFileDirectory)UseCollectionExpression\AbstractCSharpUseCollectionExpressionDiagnosticAnalyzer.cs" />
<Compile Include="$(MSBuildThisFileDirectory)UseCollectionExpression\CSharpUseCollectionExpressionForBuilderDiagnosticAnalyzer.cs" />
<Compile Include="$(MSBuildThisFileDirectory)UseCollectionExpression\CSharpUseCollectionExpressionForNewDiagnosticAnalyzer.cs" />
<Compile Include="$(MSBuildThisFileDirectory)UseCollectionExpression\CSharpUseCollectionExpressionForFluentDiagnosticAnalyzer.cs" />
<Compile Include="$(MSBuildThisFileDirectory)UseCollectionExpression\CSharpUseCollectionExpressionForEmptyDiagnosticAnalyzer.cs" />
<Compile Include="$(MSBuildThisFileDirectory)UseCollectionExpression\CSharpUseCollectionExpressionForArrayDiagnosticAnalyzer.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ internal sealed partial class CSharpUseCollectionExpressionForCreateDiagnosticAn
IDEDiagnosticIds.UseCollectionExpressionForCreateDiagnosticId,
EnforceOnBuildValues.UseCollectionExpressionForCreate)
{
public const string UnwrapArgument = nameof(UnwrapArgument);

private static readonly ImmutableDictionary<string, string?> s_unwrapArgumentProperties =
ImmutableDictionary<string, string?>.Empty.Add(UnwrapArgument, UnwrapArgument);

protected override void InitializeWorker(CodeBlockStartAnalysisContext<SyntaxKind> context, INamedTypeSymbol? expressionType)
=> context.RegisterSyntaxNodeAction(context => AnalyzeInvocationExpression(context, expressionType), SyntaxKind.InvocationExpression);

Expand All @@ -40,7 +35,7 @@ private void AnalyzeInvocationExpression(SyntaxNodeAnalysisContext context, INam
return;

var invocationExpression = (InvocationExpressionSyntax)context.Node;
if (!IsCollectionFactoryCreate(semanticModel, invocationExpression, out var memberAccess, out var unwrapArgument, cancellationToken))
if (!IsCollectionFactoryCreate(semanticModel, invocationExpression, out var memberAccess, out var unwrapArgument, out var useSpread, cancellationToken))
return;

// Make sure we can actually use a collection expression in place of the full invocation.
Expand All @@ -52,9 +47,7 @@ private void AnalyzeInvocationExpression(SyntaxNodeAnalysisContext context, INam
}

var locations = ImmutableArray.Create(invocationExpression.GetLocation());
var properties = unwrapArgument ? s_unwrapArgumentProperties : ImmutableDictionary<string, string?>.Empty;
if (changesSemantics)
properties = properties.Add(UseCollectionInitializerHelpers.ChangesSemanticsName, "");
var properties = GetDiagnosticProperties(unwrapArgument, useSpread, changesSemantics);

context.ReportDiagnostic(DiagnosticHelper.Create(
Descriptor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,6 @@ internal sealed partial class CSharpUseCollectionExpressionForFluentDiagnosticAn
nameof(System.Collections.Immutable),
];

/// <summary>
/// Set of type-names that are blocked from moving over to collection expressions because the semantics of them are
/// known to be specialized, and thus could change semantics in undesirable ways if the compiler emitted its own
/// code as an replacement.
/// </summary>
private static readonly ImmutableHashSet<string?> s_bannedTypes = [
nameof(ParallelEnumerable),
nameof(ParallelQuery),
// Special internal runtime interface that is optimized for fast path conversions of collections.
"IIListProvider"];

protected override void InitializeWorker(CodeBlockStartAnalysisContext<SyntaxKind> context, INamedTypeSymbol? expressionType)
=> context.RegisterSyntaxNodeAction(context => AnalyzeMemberAccess(context, expressionType), SyntaxKind.SimpleMemberAccessExpression);

Expand Down Expand Up @@ -272,12 +261,12 @@ private static bool AnalyzeInvocation(

// Forms like `ImmutableArray.Create(...)` or `ImmutableArray.CreateRange(...)` are fine base cases.
if (current is InvocationExpressionSyntax currentInvocationExpression &&
IsCollectionFactoryCreate(semanticModel, currentInvocationExpression, out var factoryMemberAccess, out var unwrapArgument, cancellationToken))
IsCollectionFactoryCreate(semanticModel, currentInvocationExpression, out var factoryMemberAccess, out var unwrapArgument, out var useSpread, cancellationToken))
{
if (!IsListLike(current))
return false;

AddArgumentsInReverse(postMatchesInReverse, GetArguments(currentInvocationExpression, unwrapArgument), useSpread: false);
AddArgumentsInReverse(postMatchesInReverse, GetArguments(currentInvocationExpression.ArgumentList, unwrapArgument), useSpread);
return true;
}

Expand All @@ -292,7 +281,7 @@ private static bool AnalyzeInvocation(
// Down to some final collection. Like `x` in `x.Concat(y).ToArray()`. If `x` is itself is something that
// can be iterated, we can convert this to `[.. x, .. y]`. Note: we only want to do this if ending with one
// of the ToXXX Methods. If we just have `x.AddRange(y)` it's preference to keep that, versus `[.. x, ..y]`
if (!isAdditionMatch && IsIterable(current))
if (!isAdditionMatch && IsIterable(semanticModel, current, cancellationToken))
{
AddFinalMatch(current);
return true;
Expand Down Expand Up @@ -341,36 +330,8 @@ bool IsListLike(ExpressionSyntax expression)
return false;

return
Implements(type, compilation.IListOfTType()) ||
Implements(type, compilation.IReadOnlyListOfTType());
}

bool IsIterable(ExpressionSyntax expression)
{
var type = semanticModel.GetTypeInfo(expression, cancellationToken).Type;
if (type is null or IErrorTypeSymbol)
return false;

if (s_bannedTypes.Contains(type.Name))
return false;

return Implements(type, compilation.IEnumerableOfTType()) ||
type.Equals(compilation.SpanOfTType()) ||
type.Equals(compilation.ReadOnlySpanOfTType());
}

static bool Implements(ITypeSymbol type, INamedTypeSymbol? interfaceType)
{
if (interfaceType != null)
{
foreach (var baseInterface in type.AllInterfaces)
{
if (interfaceType.Equals(baseInterface.OriginalDefinition))
return true;
}
}

return false;
EqualsOrImplements(type, compilation.IListOfTType()) ||
EqualsOrImplements(type, compilation.IReadOnlyListOfTType());
}

static bool IsLegalInitializer(InitializerExpressionSyntax? initializer)
Expand Down Expand Up @@ -426,12 +387,12 @@ private static bool IsMatch(
// Check to make sure we're not calling something banned because it would change semantics. First check if the
// method itself comes from a banned type (like with an extension method).
var member = state.SemanticModel.GetSymbolInfo(memberAccess, cancellationToken).Symbol;
if (s_bannedTypes.Contains(member?.ContainingType.Name))
if (BannedTypes.Contains(member?.ContainingType.Name))
return false;

// Next, check if we're invoking this on a banned type.
var type = state.SemanticModel.GetTypeInfo(memberAccess.Expression, cancellationToken).Type;
if (s_bannedTypes.Contains(type?.Name))
if (BannedTypes.Contains(type?.Name))
return false;

return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Collections.Immutable;
using Microsoft.CodeAnalysis.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Shared.CodeStyle;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Text;
using Microsoft.CodeAnalysis.UseCollectionInitializer;

namespace Microsoft.CodeAnalysis.CSharp.UseCollectionExpression;

using static UseCollectionExpressionHelpers;

[DiagnosticAnalyzer(LanguageNames.CSharp)]
internal sealed partial class CSharpUseCollectionExpressionForNewDiagnosticAnalyzer()
: AbstractCSharpUseCollectionExpressionDiagnosticAnalyzer(
IDEDiagnosticIds.UseCollectionExpressionForNewDiagnosticId,
EnforceOnBuildValues.UseCollectionExpressionForNew)
{
protected override void InitializeWorker(CodeBlockStartAnalysisContext<SyntaxKind> context, INamedTypeSymbol? expressionType)
{
context.RegisterSyntaxNodeAction(context => AnalyzeObjectCreationExpression(context, expressionType), SyntaxKind.ObjectCreationExpression);
context.RegisterSyntaxNodeAction(context => AnalyzeImplicitObjectCreationExpression(context, expressionType), SyntaxKind.ImplicitObjectCreationExpression);
}

private void AnalyzeObjectCreationExpression(SyntaxNodeAnalysisContext context, INamedTypeSymbol? expressionType)
=> AnalyzeBaseObjectCreationExpression(context, (BaseObjectCreationExpressionSyntax)context.Node, expressionType);

private void AnalyzeImplicitObjectCreationExpression(SyntaxNodeAnalysisContext context, INamedTypeSymbol? expressionType)
=> AnalyzeBaseObjectCreationExpression(context, (BaseObjectCreationExpressionSyntax)context.Node, expressionType);

private void AnalyzeBaseObjectCreationExpression(
SyntaxNodeAnalysisContext context, BaseObjectCreationExpressionSyntax objectCreationExpression, INamedTypeSymbol? expressionType)
{
var semanticModel = context.SemanticModel;
var compilation = semanticModel.Compilation;
var syntaxTree = semanticModel.SyntaxTree;
var cancellationToken = context.CancellationToken;

if (objectCreationExpression is not { ArgumentList.Arguments: [var argument], Initializer: null })
return;

// no point in analyzing if the option is off.
var option = context.GetAnalyzerOptions().PreferCollectionExpression;
if (option.Value is CollectionExpressionPreference.Never || ShouldSkipAnalysis(context, option.Notification))
return;

var symbol = semanticModel.GetSymbolInfo(objectCreationExpression, cancellationToken).Symbol;
if (symbol is not IMethodSymbol { MethodKind: MethodKind.Constructor, Parameters: [var parameter] } ||
parameter.Type.Name != nameof(IEnumerable<int>))
{
return;
}

if (!Equals(compilation.IEnumerableOfTType(), parameter.Type.OriginalDefinition))
return;

if (!IsArgumentCompatibleWithIEnumerableOfT(semanticModel, argument, out var unwrapArgument, out var useSpread, cancellationToken))
return;

// Make sure we can actually use a collection expression in place of the full invocation.
var allowSemanticsChange = option.Value is CollectionExpressionPreference.WhenTypesLooselyMatch;
if (!CanReplaceWithCollectionExpression(
semanticModel, objectCreationExpression, expressionType, isSingletonInstance: false, allowSemanticsChange, skipVerificationForReplacedNode: true, cancellationToken, out var changesSemantics))
{
return;
}

var locations = ImmutableArray.Create(objectCreationExpression.GetLocation());
var properties = GetDiagnosticProperties(unwrapArgument, useSpread, changesSemantics);

context.ReportDiagnostic(DiagnosticHelper.Create(
Descriptor,
objectCreationExpression.NewKeyword.GetLocation(),
option.Notification,
context.Options,
additionalLocations: locations,
properties));

var additionalUnnecessaryLocations = ImmutableArray.Create(
syntaxTree.GetLocation(TextSpan.FromBounds(
objectCreationExpression.SpanStart,
objectCreationExpression.ArgumentList.OpenParenToken.Span.End)),
objectCreationExpression.ArgumentList.CloseParenToken.GetLocation());

context.ReportDiagnostic(DiagnosticHelper.CreateWithLocationTags(
UnnecessaryCodeDescriptor,
additionalUnnecessaryLocations[0],
NotificationOption2.ForSeverity(UnnecessaryCodeDescriptor.DefaultSeverity),
context.Options,
additionalLocations: locations,
additionalUnnecessaryLocations: additionalUnnecessaryLocations,
properties));
}
}
Loading

0 comments on commit 747210f

Please sign in to comment.