Skip to content

Commit

Permalink
Add SupportedOSPlatformAttribute to generated code
Browse files Browse the repository at this point in the history
Extern methods and COM interfaces-as-structs get it. Genuine COM interfaces don't get it because the attribute forbids application on interfaces.

Closes #40
  • Loading branch information
AArnott committed Mar 17, 2021
1 parent 4c3c8f4 commit 54cf79a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 15 deletions.
44 changes: 39 additions & 5 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public class Generator : IDisposable
private static readonly TypeSyntax SafeHandleTypeSyntax = IdentifierName("SafeHandle");
private static readonly IdentifierNameSyntax IntPtrTypeSyntax = IdentifierName(nameof(IntPtr));
private static readonly AttributeSyntax PreserveSigAttribute = Attribute(IdentifierName("PreserveSig"));
private static readonly AttributeSyntax SupportedOSPlatformAttribute = Attribute(IdentifierName("SupportedOSPlatform"));
private static readonly AttributeListSyntax DefaultDllImportSearchPathsAttributeList = AttributeList().AddAttributes(
Attribute(IdentifierName("DefaultDllImportSearchPaths")).AddArgumentListArguments(AttributeArgument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(DllImportSearchPath)), IdentifierName(nameof(DllImportSearchPath.System32))))));

Expand Down Expand Up @@ -320,6 +321,7 @@ public class Generator : IDisposable
private readonly CSharpCompilation? compilation;
private readonly CSharpParseOptions? parseOptions;
private readonly bool canCallCreateSpan;
private readonly bool generateSupportedOSPlatformAttributes;

/// <summary>
/// Initializes a new instance of the <see cref="Generator"/> class.
Expand All @@ -336,6 +338,7 @@ public Generator(Stream metadataLibraryStream, GeneratorOptions? options = null,
this.parseOptions = parseOptions;

this.canCallCreateSpan = this.compilation?.GetTypeByMetadataName(typeof(MemoryMarshal).FullName)?.GetMembers("CreateSpan").Any() is true;
this.generateSupportedOSPlatformAttributes = this.compilation?.GetTypeByMetadataName("System.Runtime.Versioning.SupportedOSPlatformAttribute") is object;

if (options.AllowMarshaling)
{
Expand Down Expand Up @@ -843,16 +846,25 @@ public IReadOnlyDictionary<string, CompilationUnitSyntax> GetCompilationUnits(Ca
}
}

var usingDirectives = new List<UsingDirectiveSyntax>
{
UsingDirective(AliasQualifiedName(IdentifierName(Token(SyntaxKind.GlobalKeyword)), IdentifierName(nameof(System)))),
UsingDirective(AliasQualifiedName(IdentifierName(Token(SyntaxKind.GlobalKeyword)), IdentifierName(nameof(System) + "." + nameof(System.Diagnostics)))),
UsingDirective(ParseName(GlobalNamespacePrefix + SystemRuntimeCompilerServices)),
UsingDirective(ParseName(GlobalNamespacePrefix + SystemRuntimeInteropServices)),
};

if (this.generateSupportedOSPlatformAttributes)
{
usingDirectives.Add(UsingDirective(ParseName(GlobalNamespacePrefix + "System.Runtime.Versioning")));
}

var normalizedResults = new Dictionary<string, CompilationUnitSyntax>(StringComparer.OrdinalIgnoreCase);
results.AsParallel().WithCancellation(cancellationToken).ForAll(kv =>
{
var compilationUnit = CompilationUnit()
.AddMembers(
kv.Value.AddUsings(
UsingDirective(AliasQualifiedName(IdentifierName(Token(SyntaxKind.GlobalKeyword)), IdentifierName(nameof(System)))),
UsingDirective(AliasQualifiedName(IdentifierName(Token(SyntaxKind.GlobalKeyword)), IdentifierName(nameof(System) + "." + nameof(System.Diagnostics)))),
UsingDirective(ParseName(GlobalNamespacePrefix + SystemRuntimeCompilerServices)),
UsingDirective(ParseName(GlobalNamespacePrefix + SystemRuntimeInteropServices))))
kv.Value.AddUsings(usingDirectives.ToArray()))
.WithLeadingTrivia(ParseLeadingTrivia(AutoGeneratedHeader).Add(
Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true).AddErrorCodes(WarningsToSuppressInGeneratedCode.Select(code => IdentifierName(code)).ToArray()))))
.NormalizeWhitespace();
Expand Down Expand Up @@ -2197,6 +2209,11 @@ private void DeclareExternMethod(MethodDefinitionHandle methodDefinitionHandle)
Token(SyntaxKind.SemicolonToken));
methodDeclaration = returnType.AddReturnMarshalAs(methodDeclaration);

if (this.GetSupportedOSPlatformAttribute(methodDefinition.GetCustomAttributes()) is AttributeSyntax supportedOSPlatformAttribute)
{
methodDeclaration = methodDeclaration.AddAttributeLists(AttributeList().AddAttributes(supportedOSPlatformAttribute));
}

// Add documentation if we can find it.
methodDeclaration = AddApiDocumentation(entrypoint ?? methodName, methodDeclaration);

Expand All @@ -2216,6 +2233,18 @@ private void DeclareExternMethod(MethodDefinitionHandle methodDefinitionHandle)
}
}

private AttributeSyntax? GetSupportedOSPlatformAttribute(CustomAttributeHandleCollection attributes)
{
AttributeSyntax? supportedOSPlatformAttribute = null;
if (this.generateSupportedOSPlatformAttributes && this.FindInteropDecorativeAttribute(attributes, "SupportedOSPlatformAttribute") is CustomAttribute templateOSPlatformAttribute)
{
CustomAttributeValue<TypeSyntax> args = templateOSPlatformAttribute.DecodeValue(CustomAttributeTypeProvider.Instance);
supportedOSPlatformAttribute = SupportedOSPlatformAttribute.AddArgumentListArguments(AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal((string)args.FixedArguments[0].Value!))));
}

return supportedOSPlatformAttribute;
}

private FieldDeclarationSyntax DeclareField(FieldDefinitionHandle fieldDefHandle)
{
FieldDefinition fieldDef = this.mr.GetFieldDefinition(fieldDefHandle);
Expand Down Expand Up @@ -2443,6 +2472,11 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinition typeDef, I
iface = iface.AddAttributeLists(AttributeList().AddAttributes(GUID(guid)));
}

if (this.GetSupportedOSPlatformAttribute(typeDef.GetCustomAttributes()) is AttributeSyntax supportedOSPlatformAttribute)
{
iface = iface.AddAttributeLists(AttributeList().AddAttributes(supportedOSPlatformAttribute));
}

return iface;
}

Expand Down
62 changes: 52 additions & 10 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
Expand All @@ -28,7 +26,7 @@ public class GeneratorTests : IDisposable, IAsyncLifetime
private readonly ITestOutputHelper logger;
private readonly FileStream metadataStream;
private CSharpCompilation compilation;
private CSharpCompilation fastSpanCompilation;
private CSharpCompilation net50Compilation;
private CSharpParseOptions parseOptions;
private Generator? generator;

Expand All @@ -43,7 +41,7 @@ public GeneratorTests(ITestOutputHelper logger)

// set in InitializeAsync
this.compilation = null!;
this.fastSpanCompilation = null!;
this.net50Compilation = null!;
}

public async Task InitializeAsync()
Expand All @@ -52,8 +50,8 @@ public async Task InitializeAsync()
ReferenceAssemblies.NetStandard.NetStandard20
.AddPackages(ImmutableArray.Create(new PackageIdentity("System.Memory", "4.5.4"))));

this.fastSpanCompilation = await this.CreateCompilationAsync(
ReferenceAssemblies.NetStandard.NetStandard21);
this.net50Compilation = await this.CreateCompilationAsync(
ReferenceAssemblies.Net.Net50);
}

public Task DisposeAsync() => Task.CompletedTask;
Expand All @@ -76,13 +74,55 @@ public void TryGetEnumName(string candidate, string? declaringEnum)
Assert.Equal(declaringEnum, actualDeclaringEnum);
}

[Fact]
public void SimplestMethod()
[Theory, PairwiseData]
public void SimplestMethod(bool net50)
{
if (net50)
{
this.compilation = this.net50Compilation;
}

this.generator = new Generator(this.metadataStream, DefaultTestGeneratorOptions, this.compilation, this.parseOptions);
Assert.True(this.generator.TryGenerateExternMethod("GetTickCount"));
const string methodName = "GetTickCount";
Assert.True(this.generator.TryGenerateExternMethod(methodName));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();

var generatedMethod = this.FindGeneratedMethod(methodName).Single();
if (net50)
{
Assert.Contains(generatedMethod.AttributeLists, this.IsAttributePresent);
}
else
{
Assert.DoesNotContain(generatedMethod.AttributeLists, this.IsAttributePresent);
}
}

[Theory, PairwiseData]
public void COMInterfaceWithSupportedOSPlatform(bool net50, bool allowMarshaling)
{
if (net50)
{
this.compilation = this.net50Compilation;
}

const string typeName = "IInkCursors";
this.generator = new Generator(this.metadataStream, DefaultTestGeneratorOptions with { AllowMarshaling = allowMarshaling }, this.compilation, this.parseOptions);
Assert.True(this.generator.TryGenerateType(typeName));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();

var iface = this.FindGeneratedType(typeName).Single();

if (net50 && !allowMarshaling)
{
Assert.Contains(iface.AttributeLists, this.IsAttributePresent);
}
else
{
Assert.DoesNotContain(iface.AttributeLists, this.IsAttributePresent);
}
}

[Theory]
Expand Down Expand Up @@ -597,7 +637,7 @@ internal static unsafe ref uint ItemRef(this ref MainAVIHeader.__uint_4 @this, i
}
";

this.compilation = this.fastSpanCompilation;
this.compilation = this.net50Compilation;
this.AssertGeneratedType("MainAVIHeader", expected, expectedIndexer);
}

Expand Down Expand Up @@ -754,6 +794,8 @@ private void AssertGeneratedType(string apiName, string expectedSyntax, string?
}
}

private bool IsAttributePresent(AttributeListSyntax al) => al.Attributes.Any(a => a.Name.ToString() == "SupportedOSPlatform");

private async Task<CSharpCompilation> CreateCompilationAsync(ReferenceAssemblies references)
{
ImmutableArray<MetadataReference> metadataReferences = await references
Expand Down

0 comments on commit 54cf79a

Please sign in to comment.