Skip to content

Commit

Permalink
Use documented invalid handles when SafeHandle is null
Browse files Browse the repository at this point in the history
Also throw when the handle is not optional, but null is provided.

Fixes #755
  • Loading branch information
AArnott committed Nov 11, 2022
1 parent 63282ef commit 0d6b5a1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
27 changes: 24 additions & 3 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1756,7 +1756,7 @@ internal void RequestMacro(MethodDeclarationSyntax macro)
// Collect all the known invalid values for this handle.
// If no invalid values are given (e.g. BSTR), we'll just assume 0 is invalid.
HashSet<IntPtr> invalidHandleValues = this.GetInvalidHandleValues(((HandleTypeHandleInfo)releaseMethodParameterTypeHandleInfo).Handle);
long preferredInvalidValue = invalidHandleValues.Contains(new IntPtr(-1)) ? -1 : invalidHandleValues.FirstOrDefault().ToInt64();
IntPtr preferredInvalidValue = GetPreferredInvalidHandleValue(invalidHandleValues);

CustomAttributeHandleCollection? atts = this.GetReturnTypeCustomAttributes(releaseMethodDef);
TypeSyntaxAndMarshaling releaseMethodReturnType = releaseMethodSignature.ReturnType.ToTypeSyntax(this.externSignatureTypeSettings, atts);
Expand All @@ -1767,7 +1767,7 @@ internal void RequestMacro(MethodDeclarationSyntax macro)

MemberAccessExpressionSyntax thisHandle = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("handle"));
ExpressionSyntax intptrZero = DefaultExpression(IntPtrTypeSyntax);
ExpressionSyntax invalidHandleIntPtr = ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(preferredInvalidValue))));
ExpressionSyntax invalidHandleIntPtr = IntPtrExpr(preferredInvalidValue);

// private static readonly IntPtr INVALID_HANDLE_VALUE = new IntPtr(-1);
IdentifierNameSyntax invalidValueFieldName = IdentifierName("INVALID_HANDLE_VALUE");
Expand Down Expand Up @@ -2249,6 +2249,8 @@ protected virtual void Dispose(bool disposing)

private static SyntaxToken TokenWithLineFeed(SyntaxKind syntaxKind) => SyntaxFactory.Token(TriviaList(), syntaxKind, TriviaList(LineFeed));

private static IntPtr GetPreferredInvalidHandleValue(HashSet<IntPtr> invalidHandleValues) => invalidHandleValues.Contains(new IntPtr(-1)) ? new IntPtr(-1) : invalidHandleValues.FirstOrDefault();

private static bool RequiresUnsafe(TypeSyntax? typeSyntax) => typeSyntax is PointerTypeSyntax || typeSyntax is FunctionPointerTypeSyntax;

private static string GetClassNameForModule(string moduleName) =>
Expand Down Expand Up @@ -2622,6 +2624,9 @@ private static Guid DecodeGuidFromAttribute(CustomAttribute guidAttribute)

private static bool IsHresult(TypeHandleInfo? typeHandleInfo) => typeHandleInfo is HandleTypeHandleInfo handleInfo && handleInfo.IsType("HRESULT");

private static ExpressionSyntax IntPtrExpr(IntPtr value) => ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value.ToInt64()))));

private T AddApiDocumentation<T>(string api, T memberDeclaration)
where T : MemberDeclarationSyntax
{
Expand Down Expand Up @@ -4717,6 +4722,22 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(externParam.Type).AddVariables(
VariableDeclarator(typeDefHandleName.Identifier))));

// throw new ArgumentNullException(nameof(hTemplateFile));
StatementSyntax nullHandleStatement = ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentNullException))).WithArgumentList(ArgumentList().AddArguments(Argument(NameOfExpression(IdentifierName(externParam.Identifier.ValueText))))));
if (isOptional)
{
HashSet<IntPtr> invalidValues = this.GetInvalidHandleValues(parameterHandleTypeInfo.Handle);
if (invalidValues.Count > 0)
{
// (HANDLE)new IntPtr(-1);
IntPtr invalidValue = GetPreferredInvalidHandleValue(invalidValues);
ExpressionSyntax invalidExpression = CastExpression(externParam.Type, IntPtrExpr(invalidValue));

// hTemplateFileLocal = invalid-handle-value;
nullHandleStatement = ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, invalidExpression));
}
}

// if (hTemplateFile is object)
leadingStatements.Add(IfStatement(
BinaryExpression(SyntaxKind.IsExpression, origName, PredefinedType(Token(SyntaxKind.ObjectKeyword))),
Expand All @@ -4734,7 +4755,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousGetHandle))), ArgumentList())))
.WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken)))),
//// else hTemplateFileLocal = default;
ElseClause(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, DefaultExpression(externParam.Type.WithoutTrailingTrivia())).WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken))))));
ElseClause(nullHandleStatement)));

// if (hTemplateFileAddRef)
// hTemplateFile.DangerousRelease();
Expand Down
3 changes: 2 additions & 1 deletion test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ public void InterestingAPIs(
"PZZWSTR",
"PCZZSTR",
"PCZZWSTR",
"NCryptImportKey", // friendly overload takes SafeHandle backed by a UIntPtr instead of IntPtr
"IUIAutomation", // non-preservesig retval COM method with a array size index parameter
"IHTMLWindow2", // contains properties named with C# reserved keywords
"CreateFile", // built-in SafeHandle use
Expand Down Expand Up @@ -2653,7 +2654,7 @@ internal static unsafe Microsoft.Win32.SafeHandles.SafeFileHandle CreateFile(str
hTemplateFileLocal = (winmdroot.Foundation.HANDLE)hTemplateFile.DangerousGetHandle();
}}
else
hTemplateFileLocal = default(winmdroot.Foundation.HANDLE);
hTemplateFileLocal= new winmdroot.Foundation.HANDLE (new IntPtr(-1L));
winmdroot.Foundation.HANDLE __result = PInvoke.CreateFile(lpFileNameLocal, dwDesiredAccess, dwShareMode, lpSecurityAttributes.HasValue ? &lpSecurityAttributesLocal : null, dwCreationDisposition, dwFlagsAndAttributes, hTemplateFileLocal);
return new Microsoft.Win32.SafeHandles.SafeFileHandle(__result, ownsHandle: true);
}}
Expand Down

0 comments on commit 0d6b5a1

Please sign in to comment.