Skip to content

Commit

Permalink
Fix crash with simplify delegate invoke (#76427)
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi authored Dec 16, 2024
2 parents 070b69d + d237945 commit 192c9cc
Show file tree
Hide file tree
Showing 14 changed files with 215 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

using System;
using System.Collections.Immutable;
using System.Threading;
using Microsoft.CodeAnalysis.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.CodeStyle;
using Microsoft.CodeAnalysis.CSharp.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.LanguageService;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp.InvokeDelegateWithConditionalAccess;

Expand All @@ -23,16 +25,13 @@ internal static class Constants
}

[DiagnosticAnalyzer(LanguageNames.CSharp)]
internal class InvokeDelegateWithConditionalAccessAnalyzer : AbstractBuiltInCodeStyleDiagnosticAnalyzer
internal sealed class InvokeDelegateWithConditionalAccessAnalyzer()
: AbstractBuiltInCodeStyleDiagnosticAnalyzer(
IDEDiagnosticIds.InvokeDelegateWithConditionalAccessId,
EnforceOnBuildValues.InvokeDelegateWithConditionalAccess,
CSharpCodeStyleOptions.PreferConditionalDelegateCall,
new LocalizableResourceString(nameof(CSharpAnalyzersResources.Delegate_invocation_can_be_simplified), CSharpAnalyzersResources.ResourceManager, typeof(CSharpAnalyzersResources)))
{
public InvokeDelegateWithConditionalAccessAnalyzer()
: base(IDEDiagnosticIds.InvokeDelegateWithConditionalAccessId,
EnforceOnBuildValues.InvokeDelegateWithConditionalAccess,
CSharpCodeStyleOptions.PreferConditionalDelegateCall,
new LocalizableResourceString(nameof(CSharpAnalyzersResources.Delegate_invocation_can_be_simplified), CSharpAnalyzersResources.ResourceManager, typeof(CSharpAnalyzersResources)))
{
}

protected override void InitializeWorker(AnalysisContext context)
=> context.RegisterSyntaxNodeAction(SyntaxNodeAction, SyntaxKind.IfStatement);

Expand Down Expand Up @@ -110,7 +109,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
}

private bool TryCheckSingleIfStatementForm(
SyntaxNodeAnalysisContext syntaxContext,
SyntaxNodeAnalysisContext context,
IfStatementSyntax ifStatement,
BinaryExpressionSyntax condition,
ExpressionStatementSyntax expressionStatement,
Expand All @@ -125,17 +124,17 @@ private bool TryCheckSingleIfStatementForm(
? condition.Right
: condition.Left;

if (InvocationExpressionIsEquivalent(expr, invocationExpression))
if (InvocationExpressionIsEquivalent(expr))
{
// Looks good!
var tree = syntaxContext.SemanticModel.SyntaxTree;
var additionalLocations = ImmutableArray.Create<Location>(
Location.Create(tree, ifStatement.Span),
Location.Create(tree, expressionStatement.Span));

var tree = context.SemanticModel.SyntaxTree;
ReportDiagnostics(
syntaxContext, ifStatement, ifStatement,
expressionStatement, notificationOption, additionalLocations,
context,
ifStatement,
ifStatement,
expressionStatement,
notificationOption,
[Location.Create(tree, ifStatement.Span), Location.Create(tree, expressionStatement.Span)],
Constants.SingleIfStatementForm);

return true;
Expand All @@ -144,7 +143,7 @@ private bool TryCheckSingleIfStatementForm(

return false;

static bool InvocationExpressionIsEquivalent(ExpressionSyntax expression, InvocationExpressionSyntax invocationExpression)
bool InvocationExpressionIsEquivalent(ExpressionSyntax expression)
{
// expr(...)
if (SyntaxFactory.AreEquivalent(expression, invocationExpression.Expression, topLevel: false))
Expand All @@ -154,7 +153,10 @@ static bool InvocationExpressionIsEquivalent(ExpressionSyntax expression, Invoca
if (invocationExpression.Expression is MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.ValueText: nameof(Action.Invoke) } } memberAccessExpression &&
SyntaxFactory.AreEquivalent(expression, memberAccessExpression.Expression, topLevel: false))
{
return true;
// note: in this case, we have to make sure we're actually calling on some delegate type, not a random
// class with an 'Invoke' method.
var type = context.SemanticModel.GetTypeInfo(expression, context.CancellationToken).Type;
return type.IsDelegateType();
}

return false;
Expand Down Expand Up @@ -224,10 +226,9 @@ private bool TryCheckVariableAndIfStatementForm(
cancellationToken.ThrowIfCancellationRequested();

// look for the form "if (a != null)" or "if (null != a)"
if (!ifStatement.Parent.IsKind(SyntaxKind.Block))
{
var parentBlock = CSharpBlockFacts.Instance.GetImmediateParentExecutableBlockForStatement(ifStatement);
if (parentBlock is null)
return false;
}

if (!IsNullCheckExpression(condition.Left, condition.Right) &&
!IsNullCheckExpression(condition.Right, condition.Left))
Expand All @@ -249,62 +250,39 @@ private bool TryCheckVariableAndIfStatementForm(
if (invocationName is null)
return false;

var conditionName = condition.Left is IdentifierNameSyntax
? (IdentifierNameSyntax)condition.Left
var conditionName = condition.Left is IdentifierNameSyntax leftIdentifier
? leftIdentifier
: (IdentifierNameSyntax)condition.Right;

if (!Equals(conditionName.Identifier.ValueText, invocationName.Identifier.ValueText))
{
return false;
}

// Now make sure the previous statement is "var a = ..."
var parentBlock = (BlockSyntax)ifStatement.Parent;
var ifIndex = parentBlock.Statements.IndexOf(ifStatement);
var blockStatements = CSharpBlockFacts.Instance.GetExecutableBlockStatements(parentBlock);
var ifIndex = blockStatements.IndexOf(ifStatement);
if (ifIndex == 0)
{
return false;
}

var previousStatement = parentBlock.Statements[ifIndex - 1];
if (previousStatement is not LocalDeclarationStatementSyntax localDeclarationStatement)
{
return false;
}

var variableDeclaration = localDeclarationStatement.Declaration;

if (variableDeclaration.Variables.Count != 1)
{
return false;
}

var declarator = variableDeclaration.Variables[0];
if (declarator.Initializer == null)
{
var previousStatement = blockStatements[ifIndex - 1];
if (previousStatement is not LocalDeclarationStatementSyntax { Declaration.Variables: [{ Initializer.Value: { } initializer } declarator] } localDeclarationStatement)
return false;
}

cancellationToken.ThrowIfCancellationRequested();
if (!Equals(declarator.Identifier.ValueText, conditionName.Identifier.ValueText))
{
return false;
}

// Syntactically this looks good. Now make sure that the local is a delegate type.
var semanticModel = syntaxContext.SemanticModel;

// The initializer can't be inlined if it's an actual lambda/method reference.
// These cannot be invoked with `?.` (only delegate *values* can be).
var initializer = declarator.Initializer.Value.WalkDownParentheses();
initializer = initializer.WalkDownParentheses();
if (initializer is AnonymousFunctionExpressionSyntax)
return false;

var initializerSymbol = semanticModel.GetSymbolInfo(initializer, cancellationToken).GetAnySymbol();
if (initializerSymbol is IMethodSymbol)
{
return false;
}

var localSymbol = (ILocalSymbol)semanticModel.GetRequiredDeclaredSymbol(declarator, cancellationToken);

Expand All @@ -316,15 +294,14 @@ private bool TryCheckVariableAndIfStatementForm(
return false;

// Looks good!
var tree = semanticModel.SyntaxTree;
var additionalLocations = ImmutableArray.Create(
Location.Create(tree, localDeclarationStatement.Span),
Location.Create(tree, ifStatement.Span),
Location.Create(tree, expressionStatement.Span));

ReportDiagnostics(syntaxContext,
localDeclarationStatement, ifStatement, expressionStatement,
notificationOption, additionalLocations, Constants.VariableAndIfStatementForm);
ReportDiagnostics(
syntaxContext,
localDeclarationStatement,
ifStatement,
expressionStatement,
notificationOption,
[localDeclarationStatement.GetLocation(), ifStatement.GetLocation(), expressionStatement.GetLocation()],
Constants.VariableAndIfStatementForm);

return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,11 @@ private void AnalyzeSyntax(SyntaxNodeAnalysisContext context)
var outermostUsing = (UsingStatementSyntax)context.Node;
var semanticModel = context.SemanticModel;

var parentBlockLike = outermostUsing.Parent;
if (parentBlockLike is GlobalStatementSyntax)
parentBlockLike = parentBlockLike.Parent;
var parentBlockLike = CSharpBlockFacts.Instance.GetImmediateParentExecutableBlockForStatement(outermostUsing);

// Don't offer on a using statement that is parented by another using statement. We'll just offer on the
// topmost using statement.
// Don't offer on a using statement that is parented by another using statement. We'll just offer on the topmost
// using statement. Also, this is only offered in a block and compilation unit. Simple using statements are
// not allowed within switch sections.
if (parentBlockLike is not BlockSyntax and not CompilationUnitSyntax)
return;

Expand Down Expand Up @@ -179,7 +178,7 @@ private static bool PreservesSemantics(
UsingStatementSyntax innermostUsing,
CancellationToken cancellationToken)
{
var statements = (IReadOnlyList<StatementSyntax>)CSharpBlockFacts.Instance.GetExecutableBlockStatements(parentBlockLike);
var statements = CSharpBlockFacts.Instance.GetExecutableBlockStatements(parentBlockLike);
var index = statements.IndexOf(outermostUsing);

return UsingValueDoesNotLeakToFollowingStatements(semanticModel, statements, index, cancellationToken) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,12 @@ private static void HandleSingleIfStatementForm(
Diagnostic diagnostic,
CancellationToken cancellationToken)
{
var root = editor.OriginalRoot;

var ifStatementLocation = diagnostic.AdditionalLocations[0];
var expressionStatementLocation = diagnostic.AdditionalLocations[1];

var ifStatement = (IfStatementSyntax)root.FindNode(ifStatementLocation.SourceSpan);
// May be at the top level, pass `getInnermostNodeForTie: true` to peer into global statement.
var ifStatement = (IfStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(getInnermostNodeForTie: true, cancellationToken);
cancellationToken.ThrowIfCancellationRequested();

var expressionStatement = (ExpressionStatementSyntax)root.FindNode(expressionStatementLocation.SourceSpan);
// Always under another statement.block. So getInnermostNodeForTie: true` is not necessary, but keeps things consistent.
var expressionStatement = (ExpressionStatementSyntax)diagnostic.AdditionalLocations[1].FindNode(getInnermostNodeForTie: true, cancellationToken);
cancellationToken.ThrowIfCancellationRequested();

var invocationExpression = (InvocationExpressionSyntax)expressionStatement.Expression;
Expand Down Expand Up @@ -116,23 +113,19 @@ private static void HandleSingleIfStatementForm(
private static void HandleVariableAndIfStatementForm(
SyntaxEditor editor, Diagnostic diagnostic, CancellationToken cancellationToken)
{
var root = editor.OriginalRoot;

var localDeclarationLocation = diagnostic.AdditionalLocations[0];
var ifStatementLocation = diagnostic.AdditionalLocations[1];
var expressionStatementLocation = diagnostic.AdditionalLocations[2];

var localDeclarationStatement = (LocalDeclarationStatementSyntax)root.FindNode(localDeclarationLocation.SourceSpan);
// May be at the top level, pass `getInnermostNodeForTie: true` to peer into global statement.
var localDeclarationStatement = (LocalDeclarationStatementSyntax)diagnostic.AdditionalLocations[0].FindNode(getInnermostNodeForTie: true, cancellationToken);
cancellationToken.ThrowIfCancellationRequested();

var ifStatement = (IfStatementSyntax)root.FindNode(ifStatementLocation.SourceSpan);
// May be at the top level, pass `getInnermostNodeForTie: true` to peer into global statement.
var ifStatement = (IfStatementSyntax)diagnostic.AdditionalLocations[1].FindNode(getInnermostNodeForTie: true, cancellationToken);
cancellationToken.ThrowIfCancellationRequested();

var expressionStatement = (ExpressionStatementSyntax)root.FindNode(expressionStatementLocation.SourceSpan);
// Always under another statement.block. So getInnermostNodeForTie: true` is not necessary, but keeps things consistent.
var expressionStatement = (ExpressionStatementSyntax)diagnostic.AdditionalLocations[2].FindNode(getInnermostNodeForTie: true, cancellationToken);
cancellationToken.ThrowIfCancellationRequested();

var invocationExpression = (InvocationExpressionSyntax)expressionStatement.Expression;
var parentBlock = (BlockSyntax)localDeclarationStatement.GetRequiredParent();

var invokeName =
invocationExpression.Expression is MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.ValueText: nameof(Action.Invoke) } } memberAccessExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ private static SyntaxNode RewriteBlock(
SyntaxNode currentBlockLike,
ISet<UsingStatementSyntax> topmostUsingStatements)
{
var originalBlockStatements = (IReadOnlyList<StatementSyntax>)CSharpBlockFacts.Instance.GetExecutableBlockStatements(originalBlockLike);
var currentBlockStatements = (IReadOnlyList<StatementSyntax>)CSharpBlockFacts.Instance.GetExecutableBlockStatements(currentBlockLike);
var originalBlockStatements = CSharpBlockFacts.Instance.GetExecutableBlockStatements(originalBlockLike);
var currentBlockStatements = CSharpBlockFacts.Instance.GetExecutableBlockStatements(currentBlockLike);

if (originalBlockStatements.Count == currentBlockStatements.Count)
{
Expand Down
Loading

0 comments on commit 192c9cc

Please sign in to comment.