diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 451f9696..dfa7eee1 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -302,7 +302,7 @@ public class Generator : IDisposable private readonly bool generateDefaultDllImportSearchPathsAttribute; private readonly GeneratedCode committedCode = new(); private readonly GeneratedCode volatileCode; - private readonly IdentifierNameSyntax constantsClassName; + private readonly IdentifierNameSyntax methodsAndConstantsClassName; private bool needsWinRTCustomMarshaler; /// @@ -356,7 +356,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option this.functionPointerTypeSettings = this.generalTypeSettings with { QualifyNames = true }; this.errorMessageTypeSettings = this.generalTypeSettings with { QualifyNames = true }; - this.constantsClassName = IdentifierName(options.ConstantsClassName); + this.methodsAndConstantsClassName = IdentifierName(options.ClassName); } private enum FriendlyOverloadOf @@ -393,37 +393,23 @@ private enum FriendlyOverloadOf private bool WideCharOnly => this.options.WideCharOnly; - private bool GroupByModule => string.IsNullOrEmpty(this.options.MethodsClassName); - private string Namespace => this.InputAssemblyName; - private string SingleClassName => this.options.MethodsClassName ?? throw new InvalidOperationException("Not in one-class mode."); - private SyntaxKind Visibility => this.options.Public ? SyntaxKind.PublicKeyword : SyntaxKind.InternalKeyword; private IEnumerable NamespaceMembers { get { - IEnumerable result = this.GroupByModule - ? this.ExternMethodsByModuleClassName.Select(kv => - ClassDeclaration(Identifier(GetClassNameForModule(kv.Key))) - .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)) - .AddMembers(kv.ToArray())) - : from entry in this.committedCode.MembersByModule - select ClassDeclaration(Identifier(this.SingleClassName)) + IEnumerable result = + from entry in this.committedCode.MembersByModule + select ClassDeclaration(Identifier(this.options.ClassName)) .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword)) .AddMembers(entry.ToArray()) .WithLeadingTrivia(ParseLeadingTrivia(string.Format(CultureInfo.InvariantCulture, PartialPInvokeContentComment, entry.Key))) - .WithAdditionalAnnotations(new SyntaxAnnotation(SimpleFileNameAnnotation, $"{this.SingleClassName}.{entry.Key}")); + .WithAdditionalAnnotations(new SyntaxAnnotation(SimpleFileNameAnnotation, $"{this.options.ClassName}.{entry.Key}")); result = result.Concat(this.committedCode.GeneratedTypes); - ClassDeclarationSyntax constantClass = this.DeclareConstantDefiningClass(); - if (constantClass.Members.Count > 0) - { - result = result.Concat(new MemberDeclarationSyntax[] { constantClass }); - } - ClassDeclarationSyntax inlineArrayIndexerExtensionsClass = this.DeclareInlineArrayIndexerExtensionsClass(); if (inlineArrayIndexerExtensionsClass.Members.Count > 0) { @@ -436,6 +422,11 @@ select ClassDeclaration(Identifier(this.SingleClassName)) result = result.Concat(new MemberDeclarationSyntax[] { comInterfaceFriendlyExtensionsClass }); } + if (this.committedCode.Fields.Any()) + { + result = result.Concat(new MemberDeclarationSyntax[] { this.DeclareConstantDefiningClass() }); + } + return result; } } @@ -1391,9 +1382,7 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle) string releaseMethodModule = this.GetNormalizedModuleName(releaseMethodDef.GetImport()); var safeHandleTypeIdentifier = IdentifierName(safeHandleClassName); - safeHandleType = this.GroupByModule - ? QualifiedName(IdentifierName(releaseMethodModule), safeHandleTypeIdentifier) - : safeHandleTypeIdentifier; + safeHandleType = safeHandleTypeIdentifier; MethodSignature releaseMethodSignature = releaseMethodDef.DecodeSignature(SignatureHandleProvider.Instance, null); var releaseMethodParameterType = releaseMethodSignature.ParameterTypes[0].ToTypeSyntax(this.externSignatureTypeSettings, default); @@ -1478,7 +1467,7 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle) ExpressionSyntax releaseInvocation = InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(this.GroupByModule ? releaseMethodModule : this.SingleClassName), + IdentifierName(this.options.ClassName), IdentifierName(renamedReleaseMethod ?? releaseMethod)), ArgumentList().AddArguments(Argument(CastExpression(releaseMethodParameterType.Type, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("handle")))))); BlockSyntax? releaseBlock = null; @@ -1523,12 +1512,12 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle) break; case "NTSTATUS": this.TryGenerateConstantOrThrow("STATUS_SUCCESS"); - ExpressionSyntax statusSuccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.constantsClassName, IdentifierName("STATUS_SUCCESS")); + ExpressionSyntax statusSuccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.methodsAndConstantsClassName, IdentifierName("STATUS_SUCCESS")); releaseInvocation = BinaryExpression(SyntaxKind.EqualsExpression, releaseInvocation, statusSuccess); break; case "HRESULT": this.TryGenerateConstantOrThrow("S_OK"); - ExpressionSyntax ok = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.constantsClassName, IdentifierName("S_OK")); + ExpressionSyntax ok = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, this.methodsAndConstantsClassName, IdentifierName("S_OK")); releaseInvocation = BinaryExpression(SyntaxKind.EqualsExpression, releaseInvocation, ok); break; default: @@ -1555,16 +1544,11 @@ internal void RequestConstant(FieldDefinitionHandle fieldDefHandle) .AddMembers(members.ToArray()) .WithLeadingTrivia(ParseLeadingTrivia($@" /// - /// Represents a Win32 handle that can be closed with . + /// Represents a Win32 handle that can be closed with . /// ")); this.volatileCode.AddSafeHandleType(safeHandleDeclaration); - if (this.GroupByModule) - { - this.volatileCode.AddMemberToModule(releaseMethodModule, safeHandleDeclaration); - } - return safeHandleType; } @@ -2667,8 +2651,7 @@ private void DeclareExternMethod(MethodDefinitionHandle methodDefinitionHandle) methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword)); } - NameSyntax declaringTypeName = ParseName(this.GroupByModule ? GetClassNameForModule(moduleName) : this.SingleClassName); - this.volatileCode.AddMemberToModule(moduleName, this.DeclareFriendlyOverloads(methodDefinition, methodDeclaration, declaringTypeName, FriendlyOverloadOf.ExternMethod)); + this.volatileCode.AddMemberToModule(moduleName, this.DeclareFriendlyOverloads(methodDefinition, methodDeclaration, this.methodsAndConstantsClassName, FriendlyOverloadOf.ExternMethod)); this.volatileCode.AddMemberToModule(moduleName, methodDeclaration); } catch (Exception ex) @@ -2847,7 +2830,7 @@ private FieldDeclarationSyntax DeclareConstant(FieldDefinitionHandle fieldDefHan private ClassDeclarationSyntax DeclareConstantDefiningClass() { - return ClassDeclaration(this.constantsClassName.Identifier) + return ClassDeclaration(this.methodsAndConstantsClassName.Identifier) .AddMembers(this.committedCode.Fields.ToArray()) .WithModifiers(TokenList(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.PartialKeyword))); } diff --git a/src/Microsoft.Windows.CsWin32/GeneratorOptions.cs b/src/Microsoft.Windows.CsWin32/GeneratorOptions.cs index 54749fdd..faeb7745 100644 --- a/src/Microsoft.Windows.CsWin32/GeneratorOptions.cs +++ b/src/Microsoft.Windows.CsWin32/GeneratorOptions.cs @@ -18,15 +18,10 @@ public record GeneratorOptions public bool WideCharOnly { get; init; } = true; /// - /// Gets the name of a single class under which all p/invoke methods are generated, regardless of imported module. Use for one class per imported module. + /// Gets the name of a single class under which all p/invoke methods and constants are generated, regardless of imported module. /// /// The default value is "PInvoke". - public string? MethodsClassName { get; init; } = "PInvoke"; - - /// - /// Gets the name of the single class under which all constants are generated. - /// - public string ConstantsClassName { get; init; } = "Constants"; + public string ClassName { get; init; } = "PInvoke"; /// /// Gets a value indicating whether to emit a single source file as opposed to types spread across many files. @@ -56,6 +51,10 @@ public record GeneratorOptions /// Thrown when some setting is invalid. public void Validate() { + if (string.IsNullOrWhiteSpace(this.ClassName)) + { + throw new InvalidOperationException("The ClassName property must not be null or empty."); + } } /// diff --git a/src/Microsoft.Windows.CsWin32/settings.schema.json b/src/Microsoft.Windows.CsWin32/settings.schema.json index 7a70ddfa..091fb2d3 100644 --- a/src/Microsoft.Windows.CsWin32/settings.schema.json +++ b/src/Microsoft.Windows.CsWin32/settings.schema.json @@ -42,16 +42,10 @@ "type": "boolean", "default": false }, - "methodsClassName": { - "description": "The name of a single class under which all p/invoke methods are generated, regardless of imported module. Use null for one class per imported module.", - "type": [ "string", "null" ], - "default": "PInvoke", - "pattern": "^\\w+$" - }, - "constantsClassName": { - "description": "The name of the single class under which all constants are generated.", + "className": { + "description": "The name of a single class under which all p/invoke methods and constants are generated, regardless of imported module.", "type": "string", - "default": "Constants", + "default": "PInvoke", "pattern": "^\\w+$" }, "public": { diff --git a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs index 956de299..28b0aecd 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs @@ -1008,33 +1008,19 @@ internal static unsafe ref uint ItemRef(this ref win32.Graphics.DirectShow.MainA [Fact] public void NullMethodsClass() { - this.generator = this.CreateGenerator(new GeneratorOptions { MethodsClassName = null }); - Assert.True(this.generator.TryGenerate("GetTickCount", CancellationToken.None)); - this.CollectGeneratedCode(this.generator); - this.AssertNoDiagnostics(); - Assert.Single(this.FindGeneratedType("Kernel32")); - Assert.Empty(this.FindGeneratedType("PInvoke")); + Assert.Throws(() => this.CreateGenerator(new GeneratorOptions { ClassName = null! })); } [Fact] public void RenamedMethodsClass() { - this.generator = this.CreateGenerator(new GeneratorOptions { MethodsClassName = "MyPInvoke" }); + this.generator = this.CreateGenerator(new GeneratorOptions { ClassName = "MyPInvoke" }); Assert.True(this.generator.TryGenerate("GetTickCount", CancellationToken.None)); - this.CollectGeneratedCode(this.generator); - Assert.Single(this.FindGeneratedType("MyPInvoke")); - Assert.Empty(this.FindGeneratedType("PInvoke")); - } - - [Fact] - public void RenamedConstantsClass() - { - this.generator = this.CreateGenerator(new GeneratorOptions { ConstantsClassName = "MyConstants" }); Assert.True(this.generator.TryGenerate("CDB_REPORT_BITS", CancellationToken.None)); this.CollectGeneratedCode(this.generator); this.AssertNoDiagnostics(); - Assert.Single(this.FindGeneratedType("MyConstants")); - Assert.Empty(this.FindGeneratedType("Constants")); + Assert.NotEmpty(this.FindGeneratedType("MyPInvoke")); + Assert.Empty(this.FindGeneratedType("PInvoke")); } [Theory, PairwiseData] @@ -1059,7 +1045,7 @@ public void ProjectReferenceBetweenTwoGeneratingProjects(bool internalsVisibleTo CSharpSyntaxTree.ParseText($@"[assembly: System.Runtime.CompilerServices.InternalsVisibleToAttribute(""{this.compilation.AssemblyName}"")]", this.parseOptions)); } - using var referencedGenerator = this.CreateGenerator(new GeneratorOptions { MethodsClassName = "P1" }, referencedProject); + using var referencedGenerator = this.CreateGenerator(new GeneratorOptions { ClassName = "P1" }, referencedProject); Assert.True(referencedGenerator.TryGenerate("LockWorkStation", CancellationToken.None)); Assert.True(referencedGenerator.TryGenerate("CreateFile", CancellationToken.None)); referencedProject = this.AddGeneratedCode(referencedProject, referencedGenerator); @@ -1067,7 +1053,7 @@ public void ProjectReferenceBetweenTwoGeneratingProjects(bool internalsVisibleTo // Now produce more code in a referencing project that includes at least one of the same types as generated in the referenced project. this.compilation = this.compilation.AddReferences(referencedProject.ToMetadataReference()); - this.generator = this.CreateGenerator(new GeneratorOptions { MethodsClassName = "P2" }); + this.generator = this.CreateGenerator(new GeneratorOptions { ClassName = "P2" }); Assert.True(this.generator.TryGenerate("HidD_GetAttributes", CancellationToken.None)); this.CollectGeneratedCode(this.generator); this.AssertNoDiagnostics(); diff --git a/test/WinRTInteropTest/Program.cs b/test/WinRTInteropTest/Program.cs index a2f17540..a2af682e 100644 --- a/test/WinRTInteropTest/Program.cs +++ b/test/WinRTInteropTest/Program.cs @@ -8,7 +8,6 @@ namespace WinRTInteropTest using Windows.Win32.Foundation; using Windows.Win32.Graphics.Gdi; using Windows.Win32.UI.WindowsAndMessaging; - using static Windows.Win32.Constants; using static Windows.Win32.PInvoke; internal class Program