diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index 65394c03..ba721d39 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -74,6 +74,8 @@ internal static SyntaxToken Token(SyntaxKind kind) internal static BlockSyntax Block(params StatementSyntax[] statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace); + internal static ImplicitArrayCreationExpressionSyntax ImplicitArrayCreationExpression(InitializerExpressionSyntax initializerExpression) => SyntaxFactory.ImplicitArrayCreationExpression(Token(SyntaxKind.NewKeyword), Token(SyntaxKind.OpenBracketToken), default, Token(SyntaxKind.CloseBracketToken), initializerExpression); + internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? declaration, ExpressionSyntax condition, SeparatedSyntaxList incrementors, StatementSyntax statement) { SyntaxToken semicolonToken = SyntaxFactory.Token(TriviaList(), SyntaxKind.SemicolonToken, TriviaList(Space)); @@ -112,7 +114,11 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla internal static WhileStatementSyntax WhileStatement(ExpressionSyntax expression, StatementSyntax statement) => SyntaxFactory.WhileStatement(Token(TriviaList(), SyntaxKind.WhileKeyword, TriviaList(Space)), Token(SyntaxKind.OpenParenToken), expression, Token(TriviaList(), SyntaxKind.CloseParenToken, TriviaList(LineFeed)), statement); - internal static TryStatementSyntax TryStatement(BlockSyntax block, SyntaxList catches, FinallyClauseSyntax @finally) => SyntaxFactory.TryStatement(Token(TriviaList(), SyntaxKind.TryKeyword, TriviaList(LineFeed)), block, catches, @finally); + internal static TryStatementSyntax TryStatement(BlockSyntax block, SyntaxList catches, FinallyClauseSyntax? @finally) => SyntaxFactory.TryStatement(Token(TriviaList(), SyntaxKind.TryKeyword, TriviaList(LineFeed)), block, catches, @finally!); + + internal static CatchClauseSyntax CatchClause(CatchDeclarationSyntax? catchDeclaration, CatchFilterClauseSyntax? filter, BlockSyntax block) => SyntaxFactory.CatchClause(TokenWithSpace(SyntaxKind.CatchKeyword), catchDeclaration, filter, block); + + internal static CatchDeclarationSyntax CatchDeclaration(TypeSyntax type, SyntaxToken identifier) => SyntaxFactory.CatchDeclaration(Token(SyntaxKind.OpenParenToken), type, identifier, Token(SyntaxKind.CloseParenToken)); internal static SwitchSectionSyntax SwitchSection() => SyntaxFactory.SwitchSection(); @@ -257,7 +263,7 @@ internal static SyntaxToken XmlTextNewLine(string text, bool continueXmlDocument internal static MethodDeclarationSyntax MethodDeclaration(TypeSyntax returnType, SyntaxToken identifier) => SyntaxFactory.MethodDeclaration(default(SyntaxList), default(SyntaxTokenList), returnType.WithTrailingTrivia(TriviaList(Space)), null, identifier, null, ParameterList(), default(SyntaxList), null, null, default(SyntaxToken)); - internal static MethodDeclarationSyntax MethodDeclaration(SyntaxList attributeLists, SyntaxTokenList modifiers, TypeSyntax returnType, ExplicitInterfaceSpecifierSyntax explicitInterfaceSpecifier, SyntaxToken identifier, TypeParameterListSyntax typeParameterList, ParameterListSyntax parameterList, SyntaxList constraintClauses, BlockSyntax body, SyntaxToken semicolonToken) => SyntaxFactory.MethodDeclaration(attributeLists, modifiers, returnType.WithTrailingTrivia(TriviaList(Space)), explicitInterfaceSpecifier, identifier, typeParameterList, parameterList, constraintClauses, body, semicolonToken); + internal static MethodDeclarationSyntax MethodDeclaration(SyntaxList attributeLists, SyntaxTokenList modifiers, TypeSyntax returnType, ExplicitInterfaceSpecifierSyntax? explicitInterfaceSpecifier, SyntaxToken identifier, TypeParameterListSyntax? typeParameterList, ParameterListSyntax parameterList, SyntaxList constraintClauses, BlockSyntax body, SyntaxToken semicolonToken) => SyntaxFactory.MethodDeclaration(attributeLists, modifiers, returnType.WithTrailingTrivia(TriviaList(Space)), explicitInterfaceSpecifier!, identifier, typeParameterList!, parameterList, constraintClauses, body, semicolonToken); internal static MemberDeclarationSyntax? ParseMemberDeclaration(string text, ParseOptions? options) => SyntaxFactory.ParseMemberDeclaration(text, options: options); @@ -289,7 +295,7 @@ internal static SyntaxList List(IEnumerable nodes) internal static ArgumentListSyntax ArgumentList(SeparatedSyntaxList arguments = default) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken)); - internal static AssignmentExpressionSyntax AssignmentExpression(SyntaxKind kind, ExpressionSyntax left, ExpressionSyntax right) => SyntaxFactory.AssignmentExpression(kind, left, Token(GetAssignmentExpressionOperatorTokenKind(kind)), right); + internal static AssignmentExpressionSyntax AssignmentExpression(SyntaxKind kind, ExpressionSyntax left, ExpressionSyntax right) => SyntaxFactory.AssignmentExpression(kind, left, Token(GetAssignmentExpressionOperatorTokenKind(kind)).WithLeadingTrivia(Space), right); internal static ArgumentSyntax Argument(ExpressionSyntax expression) => SyntaxFactory.Argument(expression); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Com.cs b/src/Microsoft.Windows.CsWin32/Generator.Com.cs index dc4d80d6..494ca02b 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Com.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Com.cs @@ -6,6 +6,12 @@ namespace Microsoft.Windows.CsWin32; public partial class Generator { private static readonly IdentifierNameSyntax HRThrowOnFailureMethodName = IdentifierName("ThrowOnFailure"); + + // [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })] + private static readonly AttributeListSyntax CcwEntrypointAttributes = AttributeList().AddAttributes(Attribute(IdentifierName("UnmanagedCallersOnly")).AddArgumentListArguments( + AttributeArgument(ImplicitArrayCreationExpression(InitializerExpression(SyntaxKind.ArrayInitializerExpression, SingletonSeparatedList(TypeOfExpression(IdentifierName("CallConvStdcall")))))) + .WithNameEquals(NameEquals(IdentifierName("CallConvs"))))); + private readonly HashSet injectedPInvokeHelperMethodsToFriendlyOverloadsExtensions = new(); private static Guid DecodeGuidFromAttribute(CustomAttribute guidAttribute) @@ -92,6 +98,14 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type var members = new List(); var vtblMembers = new List(); TypeSyntaxSettings typeSettings = this.comSignatureTypeSettings; + IdentifierNameSyntax pThisLocal = IdentifierName("pThis"); + ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null; + List ccwMethodsToSkip = new(); + IdentifierNameSyntax vtblParamName = IdentifierName("vtable"); + BlockSyntax populateVTableBody = Block(); + IdentifierNameSyntax objectLocal = IdentifierName("@object"); + IdentifierNameSyntax hrLocal = IdentifierName("hr"); + StatementSyntax returnSOK = ReturnStatement(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, HresultTypeSyntax, IdentifierName("S_OK"))); // It is imperative that we generate methods for all base interfaces as well, ahead of any implemented by *this* interface. var allMethods = new List(); @@ -100,7 +114,15 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type QualifiedTypeDefinitionHandle qualifiedBaseType = baseTypes.Peek(); baseTypes = baseTypes.Pop(); TypeDefinition baseType = qualifiedBaseType.Generator.Reader.GetTypeDefinition(qualifiedBaseType.DefinitionHandle); - allMethods.AddRange(baseType.GetMethods().Select(m => new QualifiedMethodDefinitionHandle(qualifiedBaseType.Generator, m))); + IEnumerable methodsThisType = baseType.GetMethods().Select(m => new QualifiedMethodDefinitionHandle(qualifiedBaseType.Generator, m)); + allMethods.AddRange(methodsThisType); + + // We do *not* emit CCW methods for IUnknown, because those are provided by ComWrappers. + if (ccwThisParameter is not null && + (qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IUnknown") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IDispatch"))) + { + ccwMethodsToSkip.AddRange(methodsThisType); + } } allMethods.AddRange(typeDef.GetMethods().Select(m => new QualifiedMethodDefinitionHandle(this, m))); @@ -121,8 +143,10 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type MethodSignature signature = methodDefinition.Method.DecodeSignature(SignatureHandleProvider.Instance, null); CustomAttributeHandleCollection? returnTypeAttributes = methodDefinition.Generator.GetReturnTypeCustomAttributes(methodDefinition.Method); TypeSyntax returnType = signature.ReturnType.ToTypeSyntax(typeSettings, returnTypeAttributes).Type; + TypeSyntax returnTypePreserveSig = returnType; ParameterListSyntax parameterList = methodDefinition.Generator.CreateParameterList(methodDefinition.Method, signature, typeSettings); + ParameterListSyntax parameterListPreserveSig = parameterList; // preserve a copy that has no mutations. FunctionPointerParameterListSyntax funcPtrParameters = FunctionPointerParameterList() .AddParameters(FunctionPointerParameter(PointerType(ifaceName))) .AddParameters(parameterList.Parameters.Select(p => FunctionPointerParameter(p.Type!).WithModifiers(p.Modifiers)).ToArray()) @@ -143,7 +167,6 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type // By doing this, we make the emitted code more trimmable by not referencing the full virtual method table and its full set of types // when the app may only invoke a subset of the methods. //// ((delegate *unmanaged [Stdcall])lpVtbl[3])(pThis, pClassID) - IdentifierNameSyntax pThisLocal = IdentifierName("pThis"); ExpressionSyntax vtblIndexingExpression = ParenthesizedExpression( CastExpression(unmanagedDelegateType, ElementAccessExpression(vtblFieldName).AddArgumentListArguments(Argument(methodOffset)))); InvocationExpressionSyntax vtblInvocation = InvocationExpression(vtblIndexingExpression) @@ -189,6 +212,18 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta resultLocalDeclaration, vtblInvocationStatement, returnStatement)).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); + + if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle)) + { + //// *inputArg = @object.Property; + StatementSyntax propertyGet = ExpressionStatement(AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)), + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName))); + this.TryGenerateConstantOrThrow("S_OK"); + AddCcwThunk(propertyGet, returnSOK); + } + break; case SyntaxKind.SetAccessorDeclaration: // vtblInvoke(pThis, value).ThrowOnFailure(); @@ -198,6 +233,18 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta VariableDeclaration(PointerType(ifaceName)).AddVariables( VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))), vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword))); + + if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle)) + { + //// @object.Property = inputArg; + StatementSyntax propertySet = ExpressionStatement(AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName), + IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText))); + this.TryGenerateConstantOrThrow("S_OK"); + AddCcwThunk(propertySet, returnSOK); + } + break; default: throw new NotSupportedException("Unsupported accessor kind: " + accessorKind); @@ -291,9 +338,9 @@ StatementSyntax InvokeVtblAndThrow() => ExpressionStatement(InvocationExpression List(), modifiers: TokenList(TokenWithSpace(SyntaxKind.PublicKeyword)), // always use public so struct can implement the COM interface returnType.WithTrailingTrivia(TriviaList(Space)), - explicitInterfaceSpecifier: null!, + explicitInterfaceSpecifier: null, SafeIdentifier(methodName), - null!, + null, parameterList, List(), body: body, @@ -312,6 +359,87 @@ StatementSyntax InvokeVtblAndThrow() => ExpressionStatement(InvocationExpression propertyOrMethod = methodDeclaration; members.AddRange(methodDefinition.Generator.DeclareFriendlyOverloads(methodDefinition.Method, methodDeclaration, IdentifierName(ifaceName.Identifier.ValueText), FriendlyOverloadOf.StructMethod, helperMethodsInStruct)); + + if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle)) + { + // Prepare the args for the thunk call. The Interface we thunk into *always* uses PreserveSig, which is super convenient for us. + ArgumentListSyntax args = ArgumentList().AddArguments(parameterListPreserveSig.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText))).ToArray()); + + // @object!.SomeMethod(args) + InvocationExpressionSyntax thunkInvoke = InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, SafeIdentifierName(methodName)), + args); + + StatementSyntax returnManagedMethodInvocation = returnTypePreserveSig is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.VoidKeyword } + ? ExpressionStatement(thunkInvoke) + : ReturnStatement(thunkInvoke); + + AddCcwThunk(returnManagedMethodInvocation); + } + } + + void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn) + { + if (ccwThisParameter is null || ccwMethodsToSkip.Contains(methodDefHandle)) + { + return; + } + + this.RequestComHelpers(context); + bool hrReturnType = returnTypePreserveSig is QualifiedNameSyntax { Right.Identifier.ValueText: "HRESULT" }; + + //// HRESULT hr = ComHelpers.UnwrapCCW(@this, out Interface? @object); + LocalDeclarationStatementSyntax hrDecl = LocalDeclarationStatement(VariableDeclaration(HresultTypeSyntax).AddVariables( + VariableDeclarator(hrLocal.Identifier).WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("ComHelpers"), IdentifierName("UnwrapCCW")), + ArgumentList().AddArguments( + Argument(pThisLocal), + Argument(DeclarationExpression(NestedCOMInterfaceName.WithTrailingTrivia(Space), SingleVariableDesignation(objectLocal.Identifier))).WithRefKindKeyword(Token(SyntaxKind.OutKeyword)))))))); + + StatementSyntax ifNullReturnStatement = hrReturnType + //// if (hr.Failed) return hr; + ? IfStatement( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrLocal, IdentifierName("Failed")), + Block().AddStatements(ReturnStatement(hrLocal))) + //// hr.ThrowOnFailure(); + : ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrLocal, HRThrowOnFailureMethodName))); + + //// catch (Exception ex) { return (HRESULT)ex.HResult; } + IdentifierNameSyntax exLocal = IdentifierName("ex"); + CatchClauseSyntax catchClause = CatchClause(CatchDeclaration(IdentifierName(nameof(Exception)).WithTrailingTrivia(Space), exLocal.Identifier), null, Block().AddStatements( + ReturnStatement(CastExpression(HresultTypeSyntax, MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, exLocal, IdentifierName(nameof(Exception.HResult))))))); + + BlockSyntax tryBlock = Block().AddStatements( + hrDecl, + ifNullReturnStatement).AddStatements(thunkInvokeAndReturn); + + BlockSyntax ccwBody = hrReturnType + //// try { ... } catch { ... } + ? Block().AddStatements(TryStatement(tryBlock, new SyntaxList(catchClause), null)) + //// { .... } // any exception is thrown back to native code. + : tryBlock; + + //// [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })] + //// private static HRESULT Clone(IEnumEventObject* @this, IEnumEventObject** ppInterface) + MethodDeclarationSyntax ccwMethod = MethodDeclaration( + new SyntaxList(CcwEntrypointAttributes), + TokenList(TokenWithSpace(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword)), + returnTypePreserveSig, + explicitInterfaceSpecifier: null, + SafeIdentifier(methodName), + typeParameterList: null, + ParameterList().WithParameters(parameterListPreserveSig.Parameters.Insert(0, ccwThisParameter)), + constraintClauses: default, + ccwBody, + semicolonToken: default); + members.Add(ccwMethod); + + populateVTableBody = populateVTableBody.AddStatements( + ExpressionStatement(AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression(SyntaxKind.PointerMemberAccessExpression, vtblParamName, innerMethodName), + PrefixUnaryExpression(SyntaxKind.AddressOfExpression, SafeIdentifierName(methodName))))); } // Add documentation if we can find it. @@ -319,11 +447,28 @@ StatementSyntax InvokeVtblAndThrow() => ExpressionStatement(InvocationExpression members.Add(propertyOrMethod); } - // We expose the vtbl struct, not because we expect folks to use it directly, but because some folks may use it to manually generate CCWs. - StructDeclarationSyntax? vtblStruct = StructDeclaration(Identifier("Vtbl")) - .AddMembers(vtblMembers.ToArray()) - .AddModifiers(TokenWithSpace(this.Visibility)); - members.Add(vtblStruct); + if (ccwThisParameter is not null) + { + // We expose the vtbl struct to support CCWs + IdentifierNameSyntax vtblStructName = IdentifierName("Vtbl"); + StructDeclarationSyntax? vtblStruct = StructDeclaration(Identifier("Vtbl")).WithTrailingTrivia(Space) + .AddMembers(vtblMembers.ToArray()) + .AddModifiers(TokenWithSpace(this.Visibility)); + members.Add(vtblStruct); + + // internal static void PopulateVTable(Vtbl* vtable) + MethodDeclarationSyntax populateVtblMethodDecl = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier("PopulateVTable")) + .AddModifiers(Token(this.Visibility), Token(SyntaxKind.StaticKeyword)) + .AddParameterListParameters(Parameter(vtblParamName.Identifier).WithType(PointerType(vtblStructName).WithTrailingTrivia(Space))) + .WithBody(populateVTableBody); + members.Add(populateVtblMethodDecl); + + if (populateVTableBody.Statements.Count != allMethods.Count - ccwMethodsToSkip.Count) + { + // We failed to initialize all the necessary vtbl entries. + throw new GenerationFailedException("Internal error while generating CCW vtbl initializer."); + } + } // private void** lpVtbl; // Vtbl* (but we avoid strong typing to enable trimming the entire vtbl struct away) members.Add(FieldDeclaration(VariableDeclaration(PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))))).AddVariables(VariableDeclarator(vtblFieldName.Identifier))).AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword))); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Features.cs b/src/Microsoft.Windows.CsWin32/Generator.Features.cs index 9a9331a2..115b8537 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Features.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Features.cs @@ -9,6 +9,7 @@ public partial class Generator private readonly bool canCallCreateSpan; private readonly bool canUseUnsafeAsRef; private readonly bool canUseUnsafeNullRef; + private readonly bool canUseUnmanagedCallersOnlyAttribute; private readonly bool unscopedRefAttributePredefined; private readonly INamedTypeSymbol? runtimeFeatureClass; private readonly bool generateSupportedOSPlatformAttributes; diff --git a/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs b/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs index 6d08fd02..8eebc3f5 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Invariants.cs @@ -77,6 +77,7 @@ public partial class Generator private const string OriginalDelegateAnnotation = "OriginalDelegate"; private static readonly Dictionary PInvokeHelperMethods; + private static readonly ClassDeclarationSyntax ComHelperClass; private static readonly Dictionary PInvokeMacros; private static readonly string AutoGeneratedHeader = @"// ------------------------------------------------------------------------------ @@ -302,7 +303,7 @@ public partial class Generator AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(ThisAssembly.AssemblyName))), AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(ThisAssembly.AssemblyInformationalVersion))))); - private static readonly TypeSyntax HresultTypeSyntax = IdentifierName("HRESULT"); + private static readonly TypeSyntax HresultTypeSyntax = QualifiedName(QualifiedName(IdentifierName(GlobalWinmdRootNamespaceAlias), IdentifierName("Foundation")), IdentifierName("HRESULT")); /// /// Gets the set of macros that can be generated. diff --git a/src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs b/src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs index ca75257a..8ba246e7 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs @@ -206,6 +206,8 @@ internal WhitespaceRewriter() public override SyntaxNode? VisitTryStatement(TryStatementSyntax node) => base.VisitTryStatement(this.WithIndentingTrivia(node)); + public override SyntaxNode? VisitCatchClause(CatchClauseSyntax node) => base.VisitCatchClause(this.WithIndentingTrivia(node)); + public override SyntaxNode? VisitFinallyClause(FinallyClauseSyntax node) => base.VisitFinallyClause(this.WithIndentingTrivia(node)); public override SyntaxNode? VisitIfStatement(IfStatementSyntax node) diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index e39e6af5..d8f333fe 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -52,6 +52,13 @@ static Generator() } PInvokeMacros = ((ClassDeclarationSyntax)member).Members.OfType().ToDictionary(m => m.Identifier.ValueText, m => m); + + if (!TryFetchTemplate("ComHelpers", null, out member)) + { + throw new GenerationFailedException("Missing embedded resource."); + } + + ComHelperClass = (ClassDeclarationSyntax)member; } /// @@ -79,6 +86,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option this.canCallCreateSpan = this.compilation?.GetTypeByMetadataName(typeof(MemoryMarshal).FullName)?.GetMembers("CreateSpan").Any() is true; this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true; this.canUseUnsafeNullRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("NullRef").Any() is true; + this.canUseUnmanagedCallersOnlyAttribute = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute") is not null; this.unscopedRefAttributePredefined = this.FindTypeSymbolIfAlreadyAvailable("System.Diagnostics.CodeAnalysis.UnscopedRefAttribute") is not null; this.runtimeFeatureClass = (INamedTypeSymbol?)this.FindTypeSymbolIfAlreadyAvailable("System.Runtime.CompilerServices.RuntimeFeature"); this.comIIDInterfacePredefined = this.FindTypeSymbolIfAlreadyAvailable($"{this.Namespace}.{IComIIDGuidInterfaceName}") is not null; @@ -768,6 +776,13 @@ internal static string ReplaceCommonNamespaceWithAlias(Generator? generator, str return generator is object && generator.TryStripCommonNamespace(fullNamespace, out string? stripped) ? (stripped.Length > 0 ? $"{GlobalWinmdRootNamespaceAlias}.{stripped}" : GlobalWinmdRootNamespaceAlias) : $"global::{fullNamespace}"; } + internal void RequestComHelpers(Context context) + { + const string specialType = "ComHelpers"; + this.RequestInteropType("Windows.Win32.Foundation", "HRESULT", context); + this.volatileCode.GenerateSpecialType(specialType, () => this.volatileCode.AddSpecialType(specialType, ComHelperClass)); + } + internal bool TryStripCommonNamespace(string fullNamespace, [NotNullWhen(true)] out string? strippedNamespace) { if (fullNamespace.StartsWith(this.MetadataIndex.CommonNamespaceDot, StringComparison.Ordinal)) diff --git a/src/Microsoft.Windows.CsWin32/templates/ComHelpers.cs b/src/Microsoft.Windows.CsWin32/templates/ComHelpers.cs new file mode 100644 index 00000000..318cc38e --- /dev/null +++ b/src/Microsoft.Windows.CsWin32/templates/ComHelpers.cs @@ -0,0 +1,13 @@ +internal static unsafe partial class ComHelpers +{ + private static readonly winmdroot.Foundation.HRESULT COR_E_OBJECTDISPOSED = (winmdroot.Foundation.HRESULT)unchecked((int)0x80131622); + private static readonly winmdroot.Foundation.HRESULT S_OK = (winmdroot.Foundation.HRESULT)0; + + internal static winmdroot.Foundation.HRESULT UnwrapCCW(TThis* @this, out TInterface @object) + where TThis : unmanaged + where TInterface : class + { + @object = ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)@this); + return @object is null ? COR_E_OBJECTDISPOSED : S_OK; + } +}