diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 35914c16..6d23c374 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -357,6 +357,8 @@ public class Generator : IDisposable AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(ThisAssembly.AssemblyName))), AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(ThisAssembly.AssemblyInformationalVersion))))); + private static readonly TypeSyntax HresultTypeSyntax = IdentifierName("HRESULT"); + private readonly TypeSyntaxSettings generalTypeSettings; private readonly TypeSyntaxSettings fieldTypeSettings; private readonly TypeSyntaxSettings delegateSignatureTypeSettings; @@ -2617,6 +2619,8 @@ private static Guid DecodeGuidFromAttribute(CustomAttribute guidAttribute) (byte)args.FixedArguments[10].Value!); } + private static bool IsHresult(TypeHandleInfo? typeHandleInfo) => typeHandleInfo is HandleTypeHandleInfo handleInfo && handleInfo.IsType("HRESULT"); + private T AddApiDocumentation(string api, T memberDeclaration) where T : MemberDeclarationSyntax { @@ -3453,6 +3457,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type allMethods.AddRange(typeDef.GetMethods().Select(m => new QualifiedMethodDefinitionHandle(this, m))); int methodCounter = 0; HashSet helperMethodsInStruct = new(); + HashSet declaredProperties = new(StringComparer.Ordinal); foreach (QualifiedMethodDefinitionHandle methodDefHandle in allMethods) { methodCounter++; @@ -3492,43 +3497,127 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type .WithArgumentList(FixTrivia(ArgumentList() .AddArguments(Argument(pThisLocal)) .AddArguments(parameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText)).WithRefKindKeyword(p.Modifiers.Count > 0 ? p.Modifiers[0] : default)).ToArray()))); - StatementSyntax vtblInvocationStatement = IsVoid(returnType.Type) - ? ExpressionStatement(vtblInvocation) - : ReturnStatement(vtblInvocation); - BlockSyntax? body = Block().AddStatements( - FixedStatement( - VariableDeclaration(PointerType(ifaceName)).AddVariables( - VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), - vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); - MethodDeclarationSyntax methodDeclaration = MethodDeclaration( - List(), - modifiers: TokenList(TokenWithSpace(SyntaxKind.PublicKeyword)), // always use public so struct can implement the COM interface - returnType.Type.WithTrailingTrivia(TriviaList(Space)), - explicitInterfaceSpecifier: null!, - SafeIdentifier(methodName), - null!, - parameterList, - List(), - body: body, - semicolonToken: default); - methodDeclaration = returnType.AddReturnMarshalAs(methodDeclaration); + MemberDeclarationSyntax propertyOrMethod; + MethodDeclarationSyntax? methodDeclaration = null; + int priorPropertyDeclarationIndex = -1; - if (methodName == nameof(object.GetType) && parameterList.Parameters.Count == 0) + // We can declare this method as a property accessor if it represents a property. + // We must also confirm that the property type is the same in both cases, because sometimes they aren't (e.g. IUIAutomationProxyFactoryEntry.ClassName). + if (this.TryGetPropertyAccessorInfo(methodDefinition.Method, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) && + (declaredProperties.Add(propertyName.Identifier.ValueText) || + ((priorPropertyDeclarationIndex = members.FindIndex(m => m is PropertyDeclarationSyntax prop && prop.Identifier.ValueText == propertyName.Identifier.ValueText)) >= 0 + && ((PropertyDeclarationSyntax)members[priorPropertyDeclarationIndex]).Type.ToString() == propertyType.ToString()))) { - methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.NewKeyword)); - } + StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, IdentifierName("ThrowOnFailure")), + ArgumentList())); + + BlockSyntax? body; + switch (accessorKind) + { + case SyntaxKind.GetAccessorDeclaration: + // PropertyType __result; + IdentifierNameSyntax resultLocal = IdentifierName("__result"); + LocalDeclarationStatementSyntax resultLocalDeclaration = LocalDeclarationStatement(VariableDeclaration(propertyType).AddVariables(VariableDeclarator(resultLocal.Identifier))); + + // vtblInvoke(pThis, &__result).ThrowOnFailure(); + // vtblInvoke(pThis, out __result).ThrowOnFailure(); + ArgumentSyntax resultArgument = funcPtrParameters.Parameters[1].Modifiers.Any(SyntaxKind.OutKeyword) + ? Argument(resultLocal).WithRefKindKeyword(Token(SyntaxKind.OutKeyword)) + : Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, resultLocal)); + StatementSyntax vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThisLocal), resultArgument))); + + // return __result; + StatementSyntax returnStatement = ReturnStatement(resultLocal); + + body = Block().AddStatements( + FixedStatement( + VariableDeclaration(PointerType(ifaceName)).AddVariables( + VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), + Block().AddStatements( + resultLocalDeclaration, + vtblInvocationStatement, + returnStatement)).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); + break; + case SyntaxKind.SetAccessorDeclaration: + // vtblInvoke(pThis, value).ThrowOnFailure(); + vtblInvocationStatement = ThrowOnHRFailure(vtblInvocation.WithArgumentList(ArgumentList().AddArguments(Argument(pThisLocal), Argument(IdentifierName("value"))))); + body = Block().AddStatements( + FixedStatement( + VariableDeclaration(PointerType(ifaceName)).AddVariables( + VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), + vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); + break; + default: + throw new NotSupportedException("Unsupported accessor kind: " + accessorKind); + } + + AccessorDeclarationSyntax accessor = AccessorDeclaration(accessorKind.Value, body); + + if (priorPropertyDeclarationIndex >= 0) + { + // Add the accessor to the existing property declaration. + PropertyDeclarationSyntax priorDeclaration = (PropertyDeclarationSyntax)members[priorPropertyDeclarationIndex]; + members[priorPropertyDeclarationIndex] = priorDeclaration.WithAccessorList(priorDeclaration.AccessorList!.AddAccessors(accessor)); + continue; + } + else + { + PropertyDeclarationSyntax propertyDeclaration = PropertyDeclaration(propertyType.WithTrailingTrivia(Space), propertyName.Identifier.WithTrailingTrivia(LineFeed)); - if (methodDeclaration.ReturnType is PointerTypeSyntax || methodDeclaration.ParameterList.Parameters.Any(p => p.Type is PointerTypeSyntax)) + propertyDeclaration = propertyDeclaration.WithAccessorList(AccessorList().AddAccessors(accessor)); + + if (propertyDeclaration.Type is PointerTypeSyntax) + { + propertyDeclaration = propertyDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword)); + } + + propertyOrMethod = propertyDeclaration; + } + } + else { - methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword)); + StatementSyntax vtblInvocationStatement = IsVoid(returnType.Type) + ? ExpressionStatement(vtblInvocation) + : ReturnStatement(vtblInvocation); + BlockSyntax? body = Block().AddStatements( + FixedStatement( + VariableDeclaration(PointerType(ifaceName)).AddVariables( + VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), + vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); + + methodDeclaration = MethodDeclaration( + List(), + modifiers: TokenList(TokenWithSpace(SyntaxKind.PublicKeyword)), // always use public so struct can implement the COM interface + returnType.Type.WithTrailingTrivia(TriviaList(Space)), + explicitInterfaceSpecifier: null!, + SafeIdentifier(methodName), + null!, + parameterList, + List(), + body: body, + semicolonToken: default); + methodDeclaration = returnType.AddReturnMarshalAs(methodDeclaration); + + if (methodName == nameof(object.GetType) && parameterList.Parameters.Count == 0) + { + methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.NewKeyword)); + } + + if (methodDeclaration.ReturnType is PointerTypeSyntax || methodDeclaration.ParameterList.Parameters.Any(p => p.Type is PointerTypeSyntax)) + { + methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword)); + } + + propertyOrMethod = methodDeclaration; + + members.AddRange(methodDefinition.Generator.DeclareFriendlyOverloads(methodDefinition.Method, methodDeclaration, IdentifierName(ifaceName.Identifier.ValueText), FriendlyOverloadOf.StructMethod, helperMethodsInStruct)); } // Add documentation if we can find it. - methodDeclaration = this.AddApiDocumentation($"{ifaceName}.{methodName}", methodDeclaration); - - members.AddRange(methodDefinition.Generator.DeclareFriendlyOverloads(methodDefinition.Method, methodDeclaration, IdentifierName(ifaceName.Identifier.ValueText), FriendlyOverloadOf.StructMethod, helperMethodsInStruct)); - members.Add(methodDeclaration); + propertyOrMethod = this.AddApiDocumentation($"{ifaceName}.{methodName}", propertyOrMethod); + members.Add(propertyOrMethod); } // We expose the vtbl struct, not because we expect folks to use it directly, but because some folks may use it to manually generate CCWs. @@ -3544,11 +3633,11 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type Guid? guidAttributeValue = guidAttribute.HasValue ? DecodeGuidFromAttribute(guidAttribute.Value) : null; if (guidAttribute.HasValue) { - // internal static readonly Guid Guid = new Guid(0x1234, ...); + // internal static readonly Guid IID_Guid = new Guid(0x1234, ...); TypeSyntax guidTypeSyntax = IdentifierName(nameof(Guid)); members.Add(FieldDeclaration( VariableDeclaration(guidTypeSyntax) - .AddVariables(VariableDeclarator(Identifier("Guid")).WithInitializer(EqualsValueClause( + .AddVariables(VariableDeclarator(Identifier("IID_Guid")).WithInitializer(EqualsValueClause( GuidValue(guidAttribute.Value))))) .AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword)) .WithLeadingTrivia(ParseLeadingTrivia($"/// The IID guid for this interface.\n/// {guidAttributeValue!.Value:B}\n"))); @@ -3642,77 +3731,118 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type var members = new List(); var friendlyOverloads = new List(); + HashSet declaredProperties = new(StringComparer.Ordinal); foreach (MethodDefinitionHandle methodDefHandle in allMethods) { MethodDefinition methodDefinition = this.Reader.GetMethodDefinition(methodDefHandle); string methodName = this.Reader.GetString(methodDefinition.Name); + inheritedMethods--; try { - IdentifierNameSyntax innerMethodName = IdentifierName(methodName); - MethodSignature signature = methodDefinition.DecodeSignature(SignatureHandleProvider.Instance, null); + MemberDeclarationSyntax propertyOrMethod; + MethodDeclarationSyntax? methodDeclaration = null; + + // Consider whether we should declare this as a property. + // Even if it could be represented as a property accessor, we cannot do so if a property by the same name was already declared in anything other than the previous row. + // Adding an accessor to a property later than the very next row would screw up the virtual method table ordering. + // We must also confirm that the property type is the same in both cases, because sometimes they aren't (e.g. IUIAutomationProxyFactoryEntry.ClassName). + PropertyDeclarationSyntax? PriorRowIfSameProperty(IdentifierNameSyntax propertyName, TypeSyntax propertyType) => members.Count > 0 && members[members.Count - 1] is PropertyDeclarationSyntax lastProperty && lastProperty.Identifier.ValueText == propertyName.Identifier.ValueText && propertyType.ToString() == lastProperty.Type?.ToString() ? lastProperty : null; + if (this.TryGetPropertyAccessorInfo(methodDefinition, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) && (declaredProperties.Add(propertyName.Identifier.ValueText) || PriorRowIfSameProperty(propertyName, propertyType) is not null)) + { + AccessorDeclarationSyntax accessor = AccessorDeclaration(accessorKind.Value).WithSemicolonToken(Semicolon); - CustomAttributeHandleCollection? returnTypeAttributes = this.GetReturnTypeCustomAttributes(methodDefinition); - TypeSyntaxAndMarshaling returnTypeDetails = signature.ReturnType.ToTypeSyntax(typeSettings, returnTypeAttributes); - TypeSyntax returnType = returnTypeDetails.Type; - AttributeSyntax? returnsAttribute = MarshalAs(returnTypeDetails.MarshalAsAttribute, returnTypeDetails.NativeArrayInfo); + if (PriorRowIfSameProperty(propertyName, propertyType) is { } lastProperty) + { + // Add the accessor to the existing property declaration. + members[members.Count - 1] = lastProperty.WithAccessorList(lastProperty.AccessorList!.AddAccessors(accessor)); + continue; + } + else + { + PropertyDeclarationSyntax propertyDeclaration = PropertyDeclaration(propertyType.WithTrailingTrivia(Space), propertyName.Identifier.WithTrailingTrivia(LineFeed)); - bool preserveSig = interfaceAsSubtype - || returnType is not QualifiedNameSyntax { Right: { Identifier: { ValueText: "HRESULT" } } } - || (methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig - || this.options.ComInterop.PreserveSigMethods.Contains($"{ifaceName}.{methodName}") - || this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString()); + propertyDeclaration = propertyDeclaration.WithAccessorList(AccessorList().AddAccessors(accessor)); - ParameterListSyntax? parameterList = this.CreateParameterList(methodDefinition, signature, this.comSignatureTypeSettings); + if (propertyDeclaration.Type is PointerTypeSyntax) + { + propertyDeclaration = propertyDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword)); + } - if (!preserveSig) + propertyOrMethod = propertyDeclaration; + } + } + else { - ParameterSyntax? lastParameter = parameterList.Parameters.Count > 0 ? parameterList.Parameters[parameterList.Parameters.Count - 1] : null; - if (lastParameter?.HasAnnotation(IsRetValAnnotation) is true) + MethodSignature signature = methodDefinition.DecodeSignature(SignatureHandleProvider.Instance, null); + + CustomAttributeHandleCollection? returnTypeAttributes = this.GetReturnTypeCustomAttributes(methodDefinition); + TypeSyntaxAndMarshaling returnTypeDetails = signature.ReturnType.ToTypeSyntax(typeSettings, returnTypeAttributes); + TypeSyntax returnType = returnTypeDetails.Type; + AttributeSyntax? returnsAttribute = MarshalAs(returnTypeDetails.MarshalAsAttribute, returnTypeDetails.NativeArrayInfo); + + ParameterListSyntax? parameterList = this.CreateParameterList(methodDefinition, signature, this.comSignatureTypeSettings); + + bool preserveSig = interfaceAsSubtype + || !IsHresult(signature.ReturnType) + || (methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig + || this.options.ComInterop.PreserveSigMethods.Contains($"{ifaceName}.{methodName}") + || this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString()); + + if (!preserveSig) { - // Move the retval parameter to the return value position. - parameterList = parameterList.WithParameters(parameterList.Parameters.RemoveAt(parameterList.Parameters.Count - 1)); - returnType = lastParameter.Modifiers.Any(SyntaxKind.OutKeyword) ? lastParameter.Type! : ((PointerTypeSyntax)lastParameter.Type!).ElementType; - returnsAttribute = lastParameter.DescendantNodes().OfType().FirstOrDefault(att => att.Name.ToString() == "MarshalAs"); + ParameterSyntax? lastParameter = parameterList.Parameters.Count > 0 ? parameterList.Parameters[parameterList.Parameters.Count - 1] : null; + if (lastParameter?.HasAnnotation(IsRetValAnnotation) is true) + { + // Move the retval parameter to the return value position. + parameterList = parameterList.WithParameters(parameterList.Parameters.RemoveAt(parameterList.Parameters.Count - 1)); + returnType = lastParameter.Modifiers.Any(SyntaxKind.OutKeyword) ? lastParameter.Type! : ((PointerTypeSyntax)lastParameter.Type!).ElementType; + returnsAttribute = lastParameter.DescendantNodes().OfType().FirstOrDefault(att => att.Name.ToString() == "MarshalAs"); + } + else + { + // Remove the return type + returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); + } } - else + + methodDeclaration = MethodDeclaration(returnType.WithTrailingTrivia(TriviaList(Space)), SafeIdentifier(methodName)) + .WithParameterList(FixTrivia(parameterList)) + .WithSemicolonToken(SemicolonWithLineFeed); + if (returnsAttribute is object) { - // Remove the return type - returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); + methodDeclaration = methodDeclaration.AddAttributeLists( + AttributeList().WithTarget(AttributeTargetSpecifier(Token(SyntaxKind.ReturnKeyword))).AddAttributes(returnsAttribute)); } - } - MethodDeclarationSyntax methodDeclaration = MethodDeclaration(returnType.WithTrailingTrivia(TriviaList(Space)), SafeIdentifier(methodName)) - .WithParameterList(FixTrivia(parameterList)) - .WithSemicolonToken(SemicolonWithLineFeed); - if (returnsAttribute is object) - { - methodDeclaration = methodDeclaration.AddAttributeLists( - AttributeList().WithTarget(AttributeTargetSpecifier(Token(SyntaxKind.ReturnKeyword))).AddAttributes(returnsAttribute)); - } + if (preserveSig) + { + methodDeclaration = methodDeclaration.AddAttributeLists(AttributeList().AddAttributes(PreserveSigAttribute)); + } - if (preserveSig) - { - methodDeclaration = methodDeclaration.AddAttributeLists(AttributeList().AddAttributes(PreserveSigAttribute)); - } + if (methodDeclaration.ReturnType is PointerTypeSyntax || methodDeclaration.ParameterList.Parameters.Any(p => p.Type is PointerTypeSyntax)) + { + methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword)); + } - if (inheritedMethods-- > 0) - { - methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.NewKeyword)); + propertyOrMethod = methodDeclaration; } - if (methodDeclaration.ReturnType is PointerTypeSyntax || methodDeclaration.ParameterList.Parameters.Any(p => p.Type is PointerTypeSyntax)) + if (inheritedMethods >= 0) { - methodDeclaration = methodDeclaration.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword)); + propertyOrMethod = propertyOrMethod.AddModifiers(TokenWithSpace(SyntaxKind.NewKeyword)); } // Add documentation if we can find it. - methodDeclaration = this.AddApiDocumentation($"{ifaceName}.{methodName}", methodDeclaration); - members.Add(methodDeclaration); + propertyOrMethod = this.AddApiDocumentation($"{ifaceName}.{methodName}", propertyOrMethod); + members.Add(propertyOrMethod); - NameSyntax declaringTypeName = HandleTypeHandleInfo.GetNestingQualifiedName(this, this.Reader, typeDef, hasUnmanagedSuffix: false, isInterfaceNestedInStruct: interfaceAsSubtype); - friendlyOverloads.AddRange( - this.DeclareFriendlyOverloads(methodDefinition, methodDeclaration, declaringTypeName, FriendlyOverloadOf.InterfaceMethod, this.injectedPInvokeHelperMethodsToFriendlyOverloadsExtensions)); + if (methodDeclaration is not null) + { + NameSyntax declaringTypeName = HandleTypeHandleInfo.GetNestingQualifiedName(this, this.Reader, typeDef, hasUnmanagedSuffix: false, isInterfaceNestedInStruct: interfaceAsSubtype); + friendlyOverloads.AddRange( + this.DeclareFriendlyOverloads(methodDefinition, methodDeclaration, declaringTypeName, FriendlyOverloadOf.InterfaceMethod, this.injectedPInvokeHelperMethodsToFriendlyOverloadsExtensions)); + } } catch (Exception ex) { @@ -3748,6 +3878,73 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type return ifaceDeclaration; } + private bool TryGetPropertyAccessorInfo(MethodDefinition methodDefinition, [NotNullWhen(true)] out IdentifierNameSyntax? propertyName, [NotNullWhen(true)] out SyntaxKind? accessorKind, [NotNullWhen(true)] out TypeSyntax? propertyType) + { + propertyName = null; + accessorKind = null; + propertyType = null; + if ((methodDefinition.Attributes & MethodAttributes.SpecialName) != MethodAttributes.SpecialName) + { + return false; + } + + if ((methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig) + { + return false; + } + + ParameterHandleCollection parameters = methodDefinition.GetParameters(); + if (parameters.Count != 2) + { + return false; + } + + string methodName = this.Reader.GetString(methodDefinition.Name); + const string getterPrefix = "get_"; + const string setterPrefix = "put_"; + bool isGetter = methodName.StartsWith(getterPrefix, StringComparison.Ordinal); + bool isSetter = methodName.StartsWith(setterPrefix, StringComparison.Ordinal); + + if (isGetter || isSetter) + { + MethodSignature signature = methodDefinition.DecodeSignature(SignatureHandleProvider.Instance, null); + if (!IsHresult(signature.ReturnType)) + { + return false; + } + + Parameter propertyTypeParameter = this.Reader.GetParameter(parameters.Skip(1).Single()); + propertyType = signature.ParameterTypes[0].ToTypeSyntax(this.comSignatureTypeSettings, propertyTypeParameter.GetCustomAttributes(), propertyTypeParameter.Attributes).Type; + + if (isGetter) + { + propertyName = SafeIdentifierName(methodName.Substring(getterPrefix.Length)); + accessorKind = SyntaxKind.GetAccessorDeclaration; + + if ((propertyTypeParameter.Attributes & ParameterAttributes.Out) != ParameterAttributes.Out) + { + return false; + } + + if (propertyType is PointerTypeSyntax propertyTypePointer) + { + propertyType = propertyTypePointer.ElementType; + } + + return true; + } + + if (isSetter) + { + propertyName = SafeIdentifierName(methodName.Substring(setterPrefix.Length)); + accessorKind = SyntaxKind.SetAccessorDeclaration; + return true; + } + } + + return false; + } + private CustomAttribute? FindGuidAttribute(CustomAttributeHandleCollection attributes) => this.FindInteropDecorativeAttribute(attributes, nameof(GuidAttribute)); private Guid? FindGuidFromAttribute(TypeDefinition typeDef) => this.FindGuidFromAttribute(typeDef.GetCustomAttributes()); @@ -6880,7 +7077,15 @@ internal WhitespaceRewriter() return base.VisitAccessorList(node); } - public override SyntaxNode? VisitAccessorDeclaration(AccessorDeclarationSyntax node) => base.VisitAccessorDeclaration(this.WithIndentingTrivia(node)); + public override SyntaxNode? VisitAccessorDeclaration(AccessorDeclarationSyntax node) + { + if (node.Body is not null) + { + node = node.WithKeyword(node.Keyword.WithTrailingTrivia(LineFeed)); + } + + return base.VisitAccessorDeclaration(this.WithIndentingTrivia(node)); + } public override SyntaxNode? VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node) => base.VisitLocalDeclarationStatement(this.WithIndentingTrivia(node)); diff --git a/src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs b/src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs index 018c61a1..3e76cefc 100644 --- a/src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs +++ b/src/Microsoft.Windows.CsWin32/HandleTypeHandleInfo.cs @@ -33,6 +33,22 @@ internal HandleTypeHandleInfo(MetadataReader reader, EntityHandle handle, byte? public override string ToString() => this.ToTypeSyntaxForDisplay().ToString(); + internal bool IsType(string leafName) + { + switch (this.Handle.Kind) + { + case HandleKind.TypeDefinition: + TypeDefinition td = this.reader.GetTypeDefinition((TypeDefinitionHandle)this.Handle); + return this.reader.StringComparer.Equals(td.Name, leafName); + case HandleKind.TypeReference: + var trh = (TypeReferenceHandle)this.Handle; + TypeReference tr = this.reader.GetTypeReference(trh); + return this.reader.StringComparer.Equals(tr.Name, leafName); + default: + throw new NotSupportedException("Unrecognized handle type."); + } + } + internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs, CustomAttributeHandleCollection? customAttributes, ParameterAttributes parameterAttributes = default) { NameSyntax? nameSyntax; diff --git a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs index 7ace631c..4ffafabe 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs @@ -298,6 +298,7 @@ public void InterestingAPIs( "PCZZSTR", "PCZZWSTR", "IUIAutomation", // non-preservesig retval COM method with a array size index parameter + "IHTMLWindow2", // contains properties named with C# reserved keywords "CreateFile", // built-in SafeHandle use "CreateCursor", // 0 or -1 invalid SafeHandle generated "PlaySound", // 0 invalid SafeHandle generated @@ -526,6 +527,114 @@ public void IInpectableDerivedInterface() Assert.Empty(this.FindGeneratedType(WinRTCustomMarshalerClass)); } + [Theory, PairwiseData] + public void COMPropertiesAreGeneratedAsInterfaceProperties(bool allowMarshaling) + { + const string ifaceName = "IADsClass"; + this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = allowMarshaling }); + Assert.True(this.generator.TryGenerate(ifaceName, CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + InterfaceDeclarationSyntax ifaceSyntax; + if (allowMarshaling) + { + ifaceSyntax = Assert.Single(this.FindGeneratedType(ifaceName).OfType()); + } + else + { + StructDeclarationSyntax structSyntax = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType(ifaceName)); + ifaceSyntax = Assert.Single(structSyntax.Members.OfType(), m => m.Identifier.ValueText == "Interface"); + } + + // Check a property where we expect just a getter. + PropertyDeclarationSyntax getProperty = Assert.Single(ifaceSyntax.Members.OfType(), m => m.Identifier.ValueText == "PrimaryInterface"); + Assert.True(HasAccessor(getProperty, SyntaxKind.GetAccessorDeclaration)); + Assert.False(HasAccessor(getProperty, SyntaxKind.SetAccessorDeclaration)); + + // Check a property where we expect both a getter and setter. + PropertyDeclarationSyntax getSetProperty = Assert.Single(ifaceSyntax.Members.OfType(), m => m.Identifier.ValueText == "CLSID"); + Assert.True(HasAccessor(getSetProperty, SyntaxKind.GetAccessorDeclaration)); + Assert.True(HasAccessor(getSetProperty, SyntaxKind.SetAccessorDeclaration)); + + bool HasAccessor(PropertyDeclarationSyntax property, SyntaxKind kind) => property.AccessorList?.Accessors.SingleOrDefault(a => a.IsKind(kind)) is not null; + } + + [Theory, PairwiseData] + public void COMPropertiesAreGeneratedAsInterfaceProperties_NonConsecutiveAccessors(bool allowMarshaling) + { + const string ifaceName = "IUIAutomationProxyFactoryEntry"; + this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = allowMarshaling }); + Assert.True(this.generator.TryGenerate(ifaceName, CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + InterfaceDeclarationSyntax ifaceSyntax; + if (allowMarshaling) + { + ifaceSyntax = Assert.Single(this.FindGeneratedType(ifaceName).OfType()); + } + else + { + StructDeclarationSyntax structSyntax = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType(ifaceName)); + ifaceSyntax = Assert.Single(structSyntax.Members.OfType(), m => m.Identifier.ValueText == "Interface"); + } + + // Check for a property where the interface declares the getter and setter in non-consecutive rows of the VMT. + // In such a case, at most only one accessor can be declared, and the other must be a method. + PropertyDeclarationSyntax getSetProperty = Assert.Single(ifaceSyntax.Members.OfType(), m => m.Identifier.ValueText == "ClassName"); + Assert.True(HasAccessor(getSetProperty, SyntaxKind.GetAccessorDeclaration)); + Assert.False(HasAccessor(getSetProperty, SyntaxKind.SetAccessorDeclaration)); + + bool HasAccessor(PropertyDeclarationSyntax property, SyntaxKind kind) => property.AccessorList?.Accessors.SingleOrDefault(a => a.IsKind(kind)) is not null; + } + + [Fact] + public void COMPropertiesAreGeneratedAsStructProperties() + { + const string ifaceName = "IADsClass"; + this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = false }); + Assert.True(this.generator.TryGenerate(ifaceName, CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + StructDeclarationSyntax structSyntax = Assert.Single(this.FindGeneratedType(ifaceName).OfType()); + + // Check a property where we expect just a getter. + PropertyDeclarationSyntax getProperty = Assert.Single(structSyntax.Members.OfType(), m => m.Identifier.ValueText == "PrimaryInterface"); + Assert.True(HasAccessor(getProperty, SyntaxKind.GetAccessorDeclaration)); + Assert.False(HasAccessor(getProperty, SyntaxKind.SetAccessorDeclaration)); + + // Check a property where we expect both a getter and setter. + PropertyDeclarationSyntax getSetProperty = Assert.Single(structSyntax.Members.OfType(), m => m.Identifier.ValueText == "CLSID"); + Assert.True(HasAccessor(getSetProperty, SyntaxKind.GetAccessorDeclaration)); + Assert.True(HasAccessor(getSetProperty, SyntaxKind.SetAccessorDeclaration)); + + bool HasAccessor(PropertyDeclarationSyntax property, SyntaxKind kind) => property.AccessorList?.Accessors.SingleOrDefault(a => a.IsKind(kind)) is not null; + } + + [Fact] + public void COMPropertiesAreGeneratedAsStructProperties_NonConsecutiveAccessors() + { + const string ifaceName = "IUIAutomationProxyFactoryEntry"; + this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = false }); + Assert.True(this.generator.TryGenerate(ifaceName, CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + StructDeclarationSyntax structSyntax = Assert.Single(this.FindGeneratedType(ifaceName).OfType()); + + // Check for a property where the interface declares the getter and setter in non-consecutive rows of the VMT. + // For structs, we can still declare both as accessors because we implement them, provided they have the same type. + PropertyDeclarationSyntax getSetProperty = Assert.Single(structSyntax.Members.OfType(), m => m.Identifier.ValueText == "CanCheckBaseClass"); + Assert.True(HasAccessor(getSetProperty, SyntaxKind.GetAccessorDeclaration)); + Assert.True(HasAccessor(getSetProperty, SyntaxKind.SetAccessorDeclaration)); + + // And in some cases, the types are *not* the same, so the first accessor gets the property, and subsequent ones get the method syntax. + PropertyDeclarationSyntax getProperty = Assert.Single(structSyntax.Members.OfType(), m => m.Identifier.ValueText == "ClassName"); + Assert.True(HasAccessor(getProperty, SyntaxKind.GetAccessorDeclaration)); + Assert.False(HasAccessor(getProperty, SyntaxKind.SetAccessorDeclaration)); + Assert.NotEmpty(structSyntax.Members.OfType().Where(m => m.Identifier.ValueText == "put_ClassName")); + + bool HasAccessor(PropertyDeclarationSyntax property, SyntaxKind kind) => property.AccessorList?.Accessors.SingleOrDefault(a => a.IsKind(kind)) is not null; + } + [Fact] public void WinRTInterfaceDoesntBringInMarshalerIfParamNotObject() { @@ -1320,9 +1429,11 @@ public void UnicodeExtenMethodsGetCharSet() Assert.Contains( generatedMethod.AttributeLists.SelectMany(al => al.Attributes), a => a.Name.ToString() == "DllImport" && - a.ArgumentList?.Arguments.Any(arg => arg is { + a.ArgumentList?.Arguments.Any(arg => arg is + { NameEquals.Name.Identifier.ValueText: nameof(DllImportAttribute.CharSet), - Expression: MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.ValueText: nameof(CharSet.Unicode) } } }) is true); + Expression: MemberAccessExpressionSyntax { Name: IdentifierNameSyntax { Identifier.ValueText: nameof(CharSet.Unicode) } } + }) is true); } [Fact] diff --git a/test/SpellChecker/Program.cs b/test/SpellChecker/Program.cs index 04893af2..cec164bb 100644 --- a/test/SpellChecker/Program.cs +++ b/test/SpellChecker/Program.cs @@ -38,12 +38,12 @@ break; } - uint startIndex = error.get_StartIndex(); - uint length = error.get_Length(); + uint startIndex = error.StartIndex; + uint length = error.Length; var word = text.Substring((int)startIndex, (int)length); - CORRECTIVE_ACTION action = error.get_CorrectiveAction(); + CORRECTIVE_ACTION action = error.CorrectiveAction; switch (action) { @@ -51,7 +51,7 @@ Console.WriteLine(@"Delete ""{0}""", word); break; case CORRECTIVE_ACTION.CORRECTIVE_ACTION_REPLACE: - PWSTR replacement = error.get_Replacement(); + PWSTR replacement = error.Replacement; Console.WriteLine(@"Replace ""{0}"" with ""{1}""", word, replacement); CoTaskMemFree(replacement); break;