Skip to content

Commit

Permalink
Extract helpers to AsyncHelper, and sort AssertThrowsShouldNotBeUsedF…
Browse files Browse the repository at this point in the history
…orAsyncThrowsCheckFixer members
  • Loading branch information
bradwilson committed Nov 9, 2023
1 parent 4ea59ac commit 7498598
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 61 deletions.
49 changes: 49 additions & 0 deletions src/xunit.analyzers.fixes/Utility/AsyncHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Xunit.Analyzers.Fixes;

public static class AsyncHelper
{
/// <summary>
/// Get a method's modifiers that include the async keyword.
/// </summary>
public static SyntaxTokenList GetModifiersWithAsyncKeywordAdded(MethodDeclarationSyntax method) =>
method.Modifiers.Any(SyntaxKind.AsyncKeyword)
? method.Modifiers
: method.Modifiers.Add(Token(SyntaxKind.AsyncKeyword));

/// <summary>
/// Get the syntax type for an updated return type to support using async.
/// </summary>
public static async Task<TypeSyntax?> GetReturnType(
MethodDeclarationSyntax method,
InvocationExpressionSyntax invocation,
Document document,
DocumentEditor editor,
CancellationToken cancellationToken)
{
// Consider the case where a custom awaiter type is awaited
if (invocation.Parent.IsKind(SyntaxKind.AwaitExpression))
return method.ReturnType;

var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
if (semanticModel is null)
return null;

var methodSymbol = semanticModel.GetSymbolInfo(method.ReturnType, cancellationToken).Symbol as ITypeSymbol;
var taskType = TypeSymbolFactory.Task(semanticModel.Compilation);
if (taskType is null)
return null;

if (taskType.IsAssignableFrom(methodSymbol))
return method.ReturnType;

return editor.Generator.TypeExpression(taskType) as TypeSyntax;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,34 @@ public AssertThrowsShouldNotBeUsedForAsyncThrowsCheckFixer() :
base(Descriptors.X2014_AssertThrowsShouldNotBeUsedForAsyncThrowsCheck.Id)
{ }

static ExpressionSyntax GetAsyncThrowsInvocation(
InvocationExpressionSyntax invocation,
string memberName,
MemberAccessExpressionSyntax memberAccess)
{
var asyncThrowsInvocation =
invocation
.WithExpression(memberAccess.WithName(GetName(memberName, memberAccess)))
.WithArgumentList(invocation.ArgumentList);

if (invocation.Parent.IsKind(SyntaxKind.AwaitExpression))
return asyncThrowsInvocation;

return
AwaitExpression(asyncThrowsInvocation.WithoutLeadingTrivia())
.WithLeadingTrivia(invocation.GetLeadingTrivia());
}

static SimpleNameSyntax GetName(
string memberName,
MemberAccessExpressionSyntax memberAccess)
{
if (memberAccess.Name is not GenericNameSyntax genericNameSyntax)
return IdentifierName(memberName);

return GenericName(IdentifierName(memberName).Identifier, genericNameSyntax.TypeArgumentList);
}

public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -66,8 +94,8 @@ static async Task<Document> UseAsyncThrowsCheck(

if (invocation.Expression is MemberAccessExpressionSyntax memberAccess)
{
var modifiers = GetModifiersWithAsyncKeywordAdded(method);
var returnType = await GetReturnType(method, invocation, document, editor, cancellationToken);
var modifiers = AsyncHelper.GetModifiersWithAsyncKeywordAdded(method);
var returnType = await AsyncHelper.GetReturnType(method, invocation, document, editor, cancellationToken);
var asyncThrowsInvocation = GetAsyncThrowsInvocation(invocation, replacement, memberAccess);

if (returnType is not null)
Expand All @@ -82,63 +110,4 @@ static async Task<Document> UseAsyncThrowsCheck(

return editor.GetChangedDocument();
}

static SyntaxTokenList GetModifiersWithAsyncKeywordAdded(MethodDeclarationSyntax method) =>
method.Modifiers.Any(SyntaxKind.AsyncKeyword)
? method.Modifiers
: method.Modifiers.Add(Token(SyntaxKind.AsyncKeyword));

static async Task<TypeSyntax?> GetReturnType(
MethodDeclarationSyntax method,
InvocationExpressionSyntax invocation,
Document document,
DocumentEditor editor,
CancellationToken cancellationToken)
{
// Consider the case where a custom awaiter type is awaited
if (invocation.Parent.IsKind(SyntaxKind.AwaitExpression))
return method.ReturnType;

var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
if (semanticModel is null)
return null;

var methodSymbol = semanticModel.GetSymbolInfo(method.ReturnType, cancellationToken).Symbol as ITypeSymbol;
var taskType = TypeSymbolFactory.Task(semanticModel.Compilation);
if (taskType is null)
return null;

if (taskType.IsAssignableFrom(methodSymbol))
return method.ReturnType;

return editor.Generator.TypeExpression(taskType) as TypeSyntax;
}

static ExpressionSyntax GetAsyncThrowsInvocation(
InvocationExpressionSyntax invocation,
string memberName,
MemberAccessExpressionSyntax memberAccess)
{
var asyncThrowsInvocation =
invocation
.WithExpression(memberAccess.WithName(GetName(memberName, memberAccess)))
.WithArgumentList(invocation.ArgumentList);

if (invocation.Parent.IsKind(SyntaxKind.AwaitExpression))
return asyncThrowsInvocation;

return
AwaitExpression(asyncThrowsInvocation.WithoutLeadingTrivia())
.WithLeadingTrivia(invocation.GetLeadingTrivia());
}

static SimpleNameSyntax GetName(
string memberName,
MemberAccessExpressionSyntax memberAccess)
{
if (memberAccess.Name is not GenericNameSyntax genericNameSyntax)
return IdentifierName(memberName);

return GenericName(IdentifierName(memberName).Identifier, genericNameSyntax.TypeArgumentList);
}
}

0 comments on commit 7498598

Please sign in to comment.