diff --git a/src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs b/src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs index ed661196..732e345f 100644 --- a/src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs +++ b/src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs @@ -76,6 +76,7 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs bool isInterface; bool isNonCOMConformingInterface; bool isManagedType = inputs.Generator?.IsManagedType(this) ?? false; + QualifiedTypeDefinitionHandle? qtdh = default; switch (this.Handle.Kind) { case HandleKind.TypeDefinition: @@ -85,6 +86,7 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs nameSyntax = inputs.QualifyNames ? GetNestingQualifiedName(inputs.Generator, this.reader, td, hasUnmanagedSuffix, isInterfaceNestedInStruct: false) : IdentifierName(this.reader.GetString(td.Name) + simpleNameSuffix); isInterface = (td.Attributes & TypeAttributes.Interface) == TypeAttributes.Interface; isNonCOMConformingInterface = isInterface && inputs.Generator?.IsNonCOMInterface(td) is true; + qtdh = inputs.Generator is not null ? new QualifiedTypeDefinitionHandle(inputs.Generator, (TypeDefinitionHandle)this.Handle) : default; break; case HandleKind.TypeReference: var trh = (TypeReferenceHandle)this.Handle; @@ -94,6 +96,14 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs nameSyntax = inputs.QualifyNames ? GetNestingQualifiedName(inputs, this.reader, tr, hasUnmanagedSuffix) : IdentifierName(this.reader.GetString(tr.Name) + simpleNameSuffix); isInterface = inputs.Generator?.IsInterface(trh) is true; isNonCOMConformingInterface = isInterface && inputs.Generator?.IsNonCOMInterface(trh) is true; + if (inputs.Generator is not null) + { + if (inputs.Generator.TryGetTypeDefHandle(this.Handle, out QualifiedTypeDefinitionHandle qtdhTmp)) + { + qtdh = qtdhTmp; + } + } + break; default: throw new NotSupportedException("Unrecognized handle type."); @@ -115,6 +125,10 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs return new TypeSyntaxAndMarshaling(bclType); } + MarshalAsAttribute? marshalAs = null; + bool isDelegate = this.IsDelegate(inputs, out QualifiedTypeDefinition delegateDefinition) + && (qtdh is null || !Generator.IsUntypedDelegate(qtdh.Value.Reader, qtdh.Value.Reader.GetTypeDefinition(qtdh.Value.DefinitionHandle))); + if (simpleName is "PWSTR" or "PSTR") { bool isConst = this.IsConstantField || MetadataUtilities.FindAttribute(this.reader, customAttributes, Generator.InteropDecorationNamespace, "ConstAttribute").HasValue; @@ -140,11 +154,11 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs return new TypeSyntaxAndMarshaling(IdentifierName(specialName)); } } - else if (TryMarshalAsObject(inputs, simpleName, out MarshalAsAttribute? marshalAs)) + else if (TryMarshalAsObject(inputs, simpleName, out marshalAs)) { return new TypeSyntaxAndMarshaling(PredefinedType(Token(SyntaxKind.ObjectKeyword)), marshalAs, null); } - else if (!inputs.AllowMarshaling && this.IsDelegate(inputs, out QualifiedTypeDefinition delegateDefinition) && inputs.Generator is object && !Generator.IsUntypedDelegate(delegateDefinition.Generator.Reader, delegateDefinition.Definition)) + else if (!inputs.AllowMarshaling && isDelegate && inputs.Generator is object && !Generator.IsUntypedDelegate(delegateDefinition.Generator.Reader, delegateDefinition.Definition)) { return new TypeSyntaxAndMarshaling(inputs.Generator.FunctionPointer(delegateDefinition)); } @@ -153,6 +167,11 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs this.RequestTypeGeneration(inputs.Generator, this.GetContext(inputs)); } + if (isDelegate) + { + marshalAs = new(UnmanagedType.FunctionPtr); + } + TypeSyntax syntax = isInterface && (!inputs.AllowMarshaling || isNonCOMConformingInterface) ? PointerType(nameSyntax) : nameSyntax; @@ -178,12 +197,12 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs marshalCookie = marshalCookie.Substring(Generator.GlobalNamespacePrefix.Length); } - return new TypeSyntaxAndMarshaling(syntax, new MarshalAsAttribute(UnmanagedType.CustomMarshaler) { MarshalCookie = marshalCookie, MarshalType = Generator.WinRTCustomMarshalerFullName }, null); + marshalAs = new MarshalAsAttribute(UnmanagedType.CustomMarshaler) { MarshalCookie = marshalCookie, MarshalType = Generator.WinRTCustomMarshalerFullName }; } } } - return new TypeSyntaxAndMarshaling(syntax); + return new TypeSyntaxAndMarshaling(syntax, marshalAs, null); } internal override bool? IsValueType(TypeSyntaxSettings inputs) diff --git a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs index 66468eb0..c141cb40 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs @@ -419,4 +419,35 @@ public void COMInterfaceIIDInterfaceOnAppropriateTFMs( Assert.DoesNotContain(actual, predicate); } } + + [Fact] + public void FunctionPointersAsParameters() + { + this.GenerateApi("IContextCallback"); + MethodDeclarationSyntax method = this.FindGeneratedMethod("ContextCallback").Single(m => m.Parent is InterfaceDeclarationSyntax); + ParameterSyntax parameter = method.ParameterList.Parameters[0]; + Assert.Contains( + parameter.AttributeLists, + al => al.Attributes.Any(a => + a is + { + Name: IdentifierNameSyntax { Identifier.ValueText: "MarshalAs" }, + ArgumentList.Arguments: [{ Expression: MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.ValueText: "FunctionPtr" } } }], + })); + } + + [Fact] + public void NoFunctionPointerForFARPROC() + { + this.GenerateApi("GetProcAddress"); + MethodDeclarationSyntax method = this.FindGeneratedMethod("GetProcAddress").Single(m => m.Modifiers.Any(SyntaxKind.ExternKeyword)); + Assert.DoesNotContain( + method.AttributeLists, + al => al.Target is { Identifier.RawKind: (int)SyntaxKind.ReturnKeyword } && al.Attributes.Any(a => + a is + { + Name: IdentifierNameSyntax { Identifier.ValueText: "MarshalAs" }, + ArgumentList.Arguments: [{ Expression: MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.ValueText: "FunctionPtr" } } }], + })); + } }