From 97e60cfb56b809e5f0456a2159d994d8d5a863c3 Mon Sep 17 00:00:00 2001 From: elachlan <2433737+elachlan@users.noreply.github.com> Date: Tue, 16 Aug 2022 20:26:38 +1000 Subject: [PATCH] Honour DoNotRelease on SafeHandles --- src/Microsoft.Windows.CsWin32/Generator.cs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index 478935f0..62fb75b6 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -27,6 +27,7 @@ public class Generator : IDisposable internal const string InteropDecorationNamespace = "Windows.Win32.Interop"; internal const string NativeArrayInfoAttribute = "NativeArrayInfoAttribute"; internal const string RAIIFreeAttribute = "RAIIFreeAttribute"; + internal const string DoNotReleaseAttribute = "DoNotReleaseAttribute"; internal const string GlobalNamespacePrefix = "global::"; internal const string GlobalWinmdRootNamespaceAlias = "winmdroot"; internal const string WinRTCustomMarshalerClass = "WinRTCustomMarshaler"; @@ -1870,14 +1871,17 @@ internal void GetBaseTypeInfo(TypeDefinition typeDef, out StringHandle baseTypeN return null; } - internal CustomAttribute? FindInteropDecorativeAttribute(CustomAttributeHandleCollection customAttributeHandles, string attributeName) + internal CustomAttribute? FindInteropDecorativeAttribute(CustomAttributeHandleCollection? customAttributeHandles, string attributeName) { - foreach (CustomAttributeHandle handle in customAttributeHandles) + if (customAttributeHandles is not null) { - CustomAttribute att = this.Reader.GetCustomAttribute(handle); - if (this.IsAttribute(att, InteropDecorationNamespace, attributeName)) + foreach (CustomAttributeHandle handle in customAttributeHandles) { - return att; + CustomAttribute att = this.Reader.GetCustomAttribute(handle); + if (this.IsAttribute(att, InteropDecorationNamespace, attributeName)) + { + return att; + } } } @@ -4263,6 +4267,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi static ParameterSyntax StripAttributes(ParameterSyntax parameter) => parameter.WithAttributeLists(List()); static ExpressionSyntax GetSpanLength(ExpressionSyntax span) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, span, IdentifierName(nameof(Span.Length))); bool isReleaseMethod = this.MetadataIndex.ReleaseMethods.Contains(externMethodDeclaration.Identifier.ValueText); + bool doNotRelease = this.FindInteropDecorativeAttribute(this.GetReturnTypeCustomAttributes(methodDefinition), DoNotReleaseAttribute) is not null; TypeSyntaxSettings parameterTypeSyntaxSettings = overloadOf switch { @@ -4341,7 +4346,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi origName, ObjectCreationExpression(safeHandleType).AddArgumentListArguments( Argument(typeDefHandleName), - Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))))); + Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))))); } } else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod)) @@ -4733,7 +4738,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi //// return new SafeHandle(result, ownsHandle: true); body = body.AddStatements(ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments( Argument(resultLocal), - Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle")))))); + Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle")))))); } else if (hasVoidReturn) {