diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index ee7ba9be..3dd483d4 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -238,6 +238,14 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi sizeParamIndex = nativeArrayInfo.CountParamIndex; sizeConst = nativeArrayInfo.CountConst; } + else if (externParam.Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ByteKeyword } } && this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), MemorySizeAttribute) is CustomAttribute att2) + { + // A very special case as documented in https://github.com/microsoft/win32metadata/issues/1555 + // where MemorySizeAttribute is applied to byte* parameters to indicate the size of the buffer. + isArray = true; + MemorySize memorySize = DecodeMemorySizeAttribute(att2); + sizeParamIndex = memorySize.BytesParamIndex; + } IdentifierNameSyntax localName = IdentifierName(origName + "Local"); if (isArray) diff --git a/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs b/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs index 000ef045..6dae1ca9 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs @@ -7,6 +7,7 @@ public partial class Generator { internal const string InteropDecorationNamespace = "Windows.Win32.Foundation.Metadata"; internal const string NativeArrayInfoAttribute = "NativeArrayInfoAttribute"; + internal const string MemorySizeAttribute = "MemorySizeAttribute"; internal const string RAIIFreeAttribute = "RAIIFreeAttribute"; internal const string DoNotReleaseAttribute = "DoNotReleaseAttribute"; internal const string GlobalNamespacePrefix = "global::"; diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 80b6a8f5..0c45b29b 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -1054,6 +1054,15 @@ private static NativeArrayInfo DecodeNativeArrayInfoAttribute(CustomAttribute na }; } + private static MemorySize DecodeMemorySizeAttribute(CustomAttribute memorySizeAttribute) + { + CustomAttributeValue args = memorySizeAttribute.DecodeValue(CustomAttributeTypeProvider.Instance); + return new MemorySize + { + BytesParamIndex = (short?)args.NamedArguments.FirstOrDefault(a => a.Name == "BytesParamIndex").Value, + }; + } + private bool TryGetRenamedMethod(string methodName, [NotNullWhen(true)] out string? newName) { if (this.WideCharOnly && IsWideFunction(methodName)) @@ -1421,6 +1430,11 @@ internal struct NativeArrayInfo internal int? CountConst { get; init; } } + internal struct MemorySize + { + internal short? BytesParamIndex { get; init; } + } + private class DirectiveTriviaRemover : CSharpSyntaxRewriter { internal static readonly DirectiveTriviaRemover Instance = new(); diff --git a/test/GenerationSandbox.Tests/GeneratedForm.cs b/test/GenerationSandbox.Tests/GeneratedForm.cs index 3b2a4851..2858e127 100644 --- a/test/GenerationSandbox.Tests/GeneratedForm.cs +++ b/test/GenerationSandbox.Tests/GeneratedForm.cs @@ -73,4 +73,10 @@ private static void PROCESS_BASIC_INFORMATION_PebBaseAddressIsPointer() PEB_unmanaged* p = null; info.PebBaseAddress = p; } + + private static void WriteFile() + { + uint written = 0; + PInvoke.WriteFile((SafeHandle?)null, new byte[2], &written, (NativeOverlapped*)null); + } } diff --git a/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs b/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs new file mode 100644 index 00000000..1e44695d --- /dev/null +++ b/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +public class FriendlyOverloadTests : GeneratorTestBase +{ + public FriendlyOverloadTests(ITestOutputHelper logger) + : base(logger) + { + } + + [Fact] + public void WriteFile() + { + const string name = "WriteFile"; + this.Generate(name); + Assert.Contains(this.FindGeneratedMethod(name), m => m.ParameterList.Parameters.Count == 4); + } + + [Fact] + public void SHGetFileInfo() + { + // This method uses MemorySize but for determining the size of a struct that another parameter points to. + // We cannot know the size of that, since it may be a v1 struct, a v2 struct, etc. + // So assert that no overload has fewer parameters. + const string name = "SHGetFileInfo"; + this.Generate(name); + Assert.All(this.FindGeneratedMethod(name), m => Assert.Equal(5, m.ParameterList.Parameters.Count)); + } + + private void Generate(string name) + { + this.compilation = this.compilation.WithOptions(this.compilation.Options.WithPlatform(Platform.X64)); + this.generator = this.CreateGenerator(); + Assert.True(this.generator.TryGenerate(name, CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + } +}