From c1f80f77e8b6630b9de5f826cfcefdb25267e873 Mon Sep 17 00:00:00 2001 From: Andrew Arnott Date: Fri, 28 Apr 2023 15:24:56 -0600 Subject: [PATCH] Recognize `MemorySizeAttribute` in metadata to improve friendly overloads Fixes #913 --- .../Generator.FriendlyOverloads.cs | 6 +++++ .../Generator.Invariants.cs | 1 + src/Microsoft.Windows.CsWin32/Generator.cs | 14 +++++++++++ test/GenerationSandbox.Tests/GeneratedForm.cs | 6 +++++ .../FriendlyOverloadTests.cs | 25 +++++++++++++++++++ 5 files changed, 52 insertions(+) create mode 100644 test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index ee7ba9be..8ff72eed 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -238,6 +238,12 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi sizeParamIndex = nativeArrayInfo.CountParamIndex; sizeConst = nativeArrayInfo.CountConst; } + else if (externParam.Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ByteKeyword } } && this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), MemorySizeAttribute) is CustomAttribute att2) + { + isArray = true; + MemorySize memorySize = DecodeMemorySizeAttribute(att2); + sizeParamIndex = memorySize.BytesParamIndex; + } IdentifierNameSyntax localName = IdentifierName(origName + "Local"); if (isArray) diff --git a/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs b/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs index 000ef045..6dae1ca9 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs @@ -7,6 +7,7 @@ public partial class Generator { internal const string InteropDecorationNamespace = "Windows.Win32.Foundation.Metadata"; internal const string NativeArrayInfoAttribute = "NativeArrayInfoAttribute"; + internal const string MemorySizeAttribute = "MemorySizeAttribute"; internal const string RAIIFreeAttribute = "RAIIFreeAttribute"; internal const string DoNotReleaseAttribute = "DoNotReleaseAttribute"; internal const string GlobalNamespacePrefix = "global::"; diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 80b6a8f5..0c45b29b 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -1054,6 +1054,15 @@ private static NativeArrayInfo DecodeNativeArrayInfoAttribute(CustomAttribute na }; } + private static MemorySize DecodeMemorySizeAttribute(CustomAttribute memorySizeAttribute) + { + CustomAttributeValue args = memorySizeAttribute.DecodeValue(CustomAttributeTypeProvider.Instance); + return new MemorySize + { + BytesParamIndex = (short?)args.NamedArguments.FirstOrDefault(a => a.Name == "BytesParamIndex").Value, + }; + } + private bool TryGetRenamedMethod(string methodName, [NotNullWhen(true)] out string? newName) { if (this.WideCharOnly && IsWideFunction(methodName)) @@ -1421,6 +1430,11 @@ internal struct NativeArrayInfo internal int? CountConst { get; init; } } + internal struct MemorySize + { + internal short? BytesParamIndex { get; init; } + } + private class DirectiveTriviaRemover : CSharpSyntaxRewriter { internal static readonly DirectiveTriviaRemover Instance = new(); diff --git a/test/GenerationSandbox.Tests/GeneratedForm.cs b/test/GenerationSandbox.Tests/GeneratedForm.cs index 3b2a4851..2858e127 100644 --- a/test/GenerationSandbox.Tests/GeneratedForm.cs +++ b/test/GenerationSandbox.Tests/GeneratedForm.cs @@ -73,4 +73,10 @@ private static void PROCESS_BASIC_INFORMATION_PebBaseAddressIsPointer() PEB_unmanaged* p = null; info.PebBaseAddress = p; } + + private static void WriteFile() + { + uint written = 0; + PInvoke.WriteFile((SafeHandle?)null, new byte[2], &written, (NativeOverlapped*)null); + } } diff --git a/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs b/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs new file mode 100644 index 00000000..e42ae541 --- /dev/null +++ b/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +public class FriendlyOverloadTests : GeneratorTestBase +{ + public FriendlyOverloadTests(ITestOutputHelper logger) + : base(logger) + { + } + + [Fact] + public void WriteFile() + { + this.Generate("WriteFile"); + Assert.Contains(this.FindGeneratedMethod("WriteFile"), m => m.ParameterList.Parameters.Count == 4); + } + + private void Generate(string name) + { + this.generator = this.CreateGenerator(); + Assert.True(this.generator.TryGenerate(name, CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + } +}