Skip to content

Commit

Permalink
Avoid SyntaxReceiver usage in Visual Studio 2022
Browse files Browse the repository at this point in the history
  • Loading branch information
sharwell committed Aug 7, 2021
1 parent 2422617 commit 3ed338c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 28 deletions.
115 changes: 88 additions & 27 deletions InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Text;

namespace Refit.Generator
Expand All @@ -18,7 +19,7 @@ namespace Refit.Generator

[Generator]
#if ROSLYN_4
public class InterfaceStubGeneratorV2 : ISourceGenerator
public class InterfaceStubGeneratorV2 : IIncrementalGenerator
#else
public class InterfaceStubGenerator : ISourceGenerator
#endif
Expand All @@ -41,31 +42,48 @@ public class InterfaceStubGenerator : ISourceGenerator
true);
#pragma warning restore RS2008 // Enable analyzer release tracking

public void Execute(GeneratorExecutionContext context)
{
GenerateInterfaceStubs(context);
}
#if !ROSLYN_4

public void GenerateInterfaceStubs(GeneratorExecutionContext context)
public void Execute(GeneratorExecutionContext context)
{
if (context.SyntaxReceiver is not SyntaxReceiver receiver)
return;

context.AnalyzerConfigOptions.GlobalOptions.TryGetValue("build_property.RefitInternalNamespace", out var refitInternalNamespace);
GenerateInterfaceStubs(
context,
static (context, diagnostic) => context.ReportDiagnostic(diagnostic),
static (context, hintName, sourceText) => context.AddSource(hintName, sourceText),
(CSharpCompilation)context.Compilation,
context.AnalyzerConfigOptions,
receiver.CandidateMethods.ToImmutableArray(),
receiver.CandidateInterfaces.ToImmutableArray());
}

#endif

public void GenerateInterfaceStubs<TContext>(
TContext context,
Action<TContext, Diagnostic> reportDiagnostic,
Action<TContext, string, SourceText> addSource,
CSharpCompilation compilation,
AnalyzerConfigOptionsProvider analyzerConfigOptions,
ImmutableArray<MethodDeclarationSyntax> candidateMethods,
ImmutableArray<InterfaceDeclarationSyntax> candidateInterfaces)
{
analyzerConfigOptions.GlobalOptions.TryGetValue("build_property.RefitInternalNamespace", out var refitInternalNamespace);

refitInternalNamespace = $"{refitInternalNamespace ?? string.Empty}RefitInternalGenerated";

// we're going to create a new compilation that contains the attribute.
// TODO: we should allow source generators to provide source during initialize, so that this step isn't required.
var options = (context.Compilation as CSharpCompilation)!.SyntaxTrees[0].Options as CSharpParseOptions;
var compilation = context.Compilation;
var options = (CSharpParseOptions)compilation.SyntaxTrees[0].Options;

var disposableInterfaceSymbol = compilation.GetTypeByMetadataName("System.IDisposable")!;
var httpMethodBaseAttributeSymbol = compilation.GetTypeByMetadataName("Refit.HttpMethodAttribute");

if(httpMethodBaseAttributeSymbol == null)
{
context.ReportDiagnostic(Diagnostic.Create(RefitNotReferenced, null));
reportDiagnostic(context, Diagnostic.Create(RefitNotReferenced, null));
return;
}

Expand All @@ -76,7 +94,7 @@ public void GenerateInterfaceStubs(GeneratorExecutionContext context)
var interfaceToNullableEnabledMap = new Dictionary<INamedTypeSymbol, bool>(SymbolEqualityComparer.Default);
#pragma warning restore RS1024 // Compare symbols correctly
var methodSymbols = new List<IMethodSymbol>();
foreach (var group in receiver.CandidateMethods.GroupBy(m => m.SyntaxTree))
foreach (var group in candidateMethods.GroupBy(m => m.SyntaxTree))
{
var model = compilation.GetSemanticModel(group.Key);
foreach (var method in group)
Expand All @@ -85,7 +103,7 @@ public void GenerateInterfaceStubs(GeneratorExecutionContext context)
var methodSymbol = model.GetDeclaredSymbol(method);
if (IsRefitMethod(methodSymbol, httpMethodBaseAttributeSymbol))
{
var isAnnotated = context.Compilation.Options.NullableContextOptions == NullableContextOptions.Enable ||
var isAnnotated = compilation.Options.NullableContextOptions == NullableContextOptions.Enable ||
model.GetNullableContext(method.SpanStart) == NullableContext.Enabled;
interfaceToNullableEnabledMap[methodSymbol!.ContainingType] = isAnnotated;

Expand All @@ -99,7 +117,7 @@ public void GenerateInterfaceStubs(GeneratorExecutionContext context)

// Look through the candidate interfaces
var interfaceSymbols = new List<INamedTypeSymbol>();
foreach(var group in receiver.CandidateInterfaces.GroupBy(i => i.SyntaxTree))
foreach(var group in candidateInterfaces.GroupBy(i => i.SyntaxTree))
{
var model = compilation.GetSemanticModel(group.Key);
foreach (var iface in group)
Expand Down Expand Up @@ -133,7 +151,7 @@ public void GenerateInterfaceStubs(GeneratorExecutionContext context)
if(!interfaces.Any()) return;


var supportsNullable = ((CSharpParseOptions)context.ParseOptions).LanguageVersion >= LanguageVersion.CSharp8;
var supportsNullable = options.LanguageVersion >= LanguageVersion.CSharp8;

var keyCount = new Dictionary<string, int>();

Expand All @@ -158,10 +176,10 @@ sealed class PreserveAttribute : global::System.Attribute
";


compilation = context.Compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(SourceText.From(attributeText, Encoding.UTF8), options));
compilation = compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(SourceText.From(attributeText, Encoding.UTF8), options));

// add the attribute text
context.AddSource("PreserveAttribute.g.cs", SourceText.From(attributeText, Encoding.UTF8));
addSource(context, "PreserveAttribute.g.cs", SourceText.From(attributeText, Encoding.UTF8));

// get the newly bound attribute
var preserveAttributeSymbol = compilation.GetTypeByMetadataName($"{refitInternalNamespace}.PreserveAttribute")!;
Expand All @@ -183,7 +201,7 @@ internal static partial class Generated
}}
#pragma warning restore
";
context.AddSource("Generated.g.cs", SourceText.From(generatedClassText, Encoding.UTF8));
addSource(context, "Generated.g.cs", SourceText.From(generatedClassText, Encoding.UTF8));

compilation = compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(SourceText.From(generatedClassText, Encoding.UTF8), options));

Expand All @@ -196,14 +214,15 @@ internal static partial class Generated
// with a refit attribute on them. Types may contain other members, without the attribute, which we'll
// need to check for and error out on

var classSource = ProcessInterface(group.Key,
var classSource = ProcessInterface(context,
reportDiagnostic,
group.Key,
group.Value,
preserveAttributeSymbol,
disposableInterfaceSymbol,
httpMethodBaseAttributeSymbol,
supportsNullable,
interfaceToNullableEnabledMap[group.Key],
context);
interfaceToNullableEnabledMap[group.Key]);

var keyName = group.Key.Name;
if(keyCount.TryGetValue(keyName, out var value))
Expand All @@ -212,19 +231,20 @@ internal static partial class Generated
}
keyCount[keyName] = value;

context.AddSource($"{keyName}.g.cs", SourceText.From(classSource, Encoding.UTF8));
addSource(context, $"{keyName}.g.cs", SourceText.From(classSource, Encoding.UTF8));
}

}

string ProcessInterface(INamedTypeSymbol interfaceSymbol,
string ProcessInterface<TContext>(TContext context,
Action<TContext, Diagnostic> reportDiagnostic,
INamedTypeSymbol interfaceSymbol,
List<IMethodSymbol> refitMethods,
ISymbol preserveAttributeSymbol,
ISymbol disposableInterfaceSymbol,
INamedTypeSymbol httpMethodBaseAttributeSymbol,
bool supportsNullable,
bool nullableEnabled,
GeneratorExecutionContext context)
bool nullableEnabled)
{

// Get the class name with the type parameters, then remove the namespace
Expand Down Expand Up @@ -337,7 +357,7 @@ partial class {ns}{classDeclaration}
!method.IsAbstract) // If an interface method has a body, it won't be abstract
continue;

ProcessNonRefitMethod(source, method, context);
ProcessNonRefitMethod(context, reportDiagnostic, source, method);
}

// Handle Dispose
Expand Down Expand Up @@ -464,7 +484,7 @@ void WriteConstraitsForTypeParameter(StringBuilder source, ITypeParameterSymbol

}

void ProcessNonRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, GeneratorExecutionContext context)
void ProcessNonRefitMethod<TContext>(TContext context, Action<TContext, Diagnostic> reportDiagnostic, StringBuilder source, IMethodSymbol methodSymbol)
{
WriteMethodOpening(source, methodSymbol, true);

Expand All @@ -477,7 +497,7 @@ void ProcessNonRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, Gen
foreach(var location in methodSymbol.Locations)
{
var diagnostic = Diagnostic.Create(InvalidRefitMember, location, methodSymbol.ContainingType.Name, methodSymbol.Name);
context.ReportDiagnostic(diagnostic);
reportDiagnostic(context, diagnostic);
}
}

Expand Down Expand Up @@ -521,6 +541,45 @@ bool IsRefitMethod(IMethodSymbol? methodSymbol, INamedTypeSymbol httpMethodAttib
return methodSymbol?.GetAttributes().Any(ad => ad.AttributeClass?.InheritsFromOrEquals(httpMethodAttibute) == true) == true;
}

#if ROSLYN_4

public void Initialize(IncrementalGeneratorInitializationContext context)
{
// We're looking for methods with an attribute that are in an interface
var candidateMethodsProvider = context.SyntaxProvider.CreateSyntaxProvider(
(syntax, cancellationToken) => syntax is MethodDeclarationSyntax { Parent: InterfaceDeclarationSyntax, AttributeLists: { Count: > 0 } },
(context, cancellationToken) => (MethodDeclarationSyntax)context.Node);

// We also look for interfaces that derive from others, so we can see if any base methods contain
// Refit methods
var candidateInterfacesProvider = context.SyntaxProvider.CreateSyntaxProvider(
(syntax, cancellationToken) => syntax is InterfaceDeclarationSyntax { BaseList: not null },
(context, cancellationToken) => (InterfaceDeclarationSyntax)context.Node);

var inputs = candidateMethodsProvider.Collect()
.Combine(candidateInterfacesProvider.Collect())
.Select((combined, cancellationToken) => (candidateMethods: combined.Left, candidateInterfaces: combined.Right))
.Combine(context.AnalyzerConfigOptionsProvider)
.Combine(context.CompilationProvider)
.Select((combined, cancellationToken) => (combined.Left.Left.candidateMethods, combined.Left.Left.candidateInterfaces, analyzerConfigOptions: combined.Left.Right, compilation: combined.Right));

context.RegisterSourceOutput(
inputs,
(context, collectedValues) =>
{
GenerateInterfaceStubs(
context,
static (context, diagnostic) => context.ReportDiagnostic(diagnostic),
static (context, hintName, sourceText) => context.AddSource(hintName, sourceText),
(CSharpCompilation)collectedValues.compilation,
collectedValues.analyzerConfigOptions,
collectedValues.candidateMethods,
collectedValues.candidateInterfaces);
});
}

#else

public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
Expand Down Expand Up @@ -550,5 +609,7 @@ methodDeclarationSyntax.Parent is InterfaceDeclarationSyntax &&
}
}
}

#endif
}
}
2 changes: 1 addition & 1 deletion Refit.Tests/InterfaceStubGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

using Task = System.Threading.Tasks.Task;
using VerifyCS = Refit.Tests.CSharpSourceGeneratorVerifier<Refit.Generator.InterfaceStubGenerator>;
using VerifyCSV2 = Refit.Tests.CSharpSourceGeneratorVerifier<Refit.Generator.InterfaceStubGeneratorV2>;
using VerifyCSV2 = Refit.Tests.CSharpIncrementalSourceGeneratorVerifier<Refit.Generator.InterfaceStubGeneratorV2>;

namespace Refit.Tests
{
Expand Down

0 comments on commit 3ed338c

Please sign in to comment.