Skip to content

Commit

Permalink
Merge pull request #760 from microsoft/fix755
Browse files Browse the repository at this point in the history
Use documented invalid handles when SafeHandle is null
  • Loading branch information
AArnott authored Nov 11, 2022
2 parents 874445e + be494bf commit 25df5f1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 9 deletions.
4 changes: 4 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static ThrowStatementSyntax ThrowStatement(ExpressionSyntax expression) => SyntaxFactory.ThrowStatement(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression, Semicolon);

internal static ThrowExpressionSyntax ThrowExpression(ExpressionSyntax expression) => SyntaxFactory.ThrowExpression(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression);

internal static ExpressionSyntax NameOfExpression(IdentifierNameSyntax identifierName) => SyntaxFactory.InvocationExpression(IdentifierName("nameof"), ArgumentList(SingletonSeparatedList(Argument(identifierName))));

internal static ReturnStatementSyntax ReturnStatement(ExpressionSyntax? expression) => SyntaxFactory.ReturnStatement(Token(TriviaList(), SyntaxKind.ReturnKeyword, TriviaList(Space)), expression!, Semicolon);

internal static DelegateDeclarationSyntax DelegateDeclaration(TypeSyntax returnType, SyntaxToken identifier) => SyntaxFactory.DelegateDeclaration(default(SyntaxList<AttributeListSyntax>), default(SyntaxTokenList), Token(TriviaList(), SyntaxKind.DelegateKeyword, TriviaList(Space)), returnType.WithTrailingTrivia(TriviaList(Space)), identifier, null, ParameterList(), default, Semicolon);
Expand Down
39 changes: 31 additions & 8 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public class Generator : IDisposable
private const string SystemRuntimeInteropServices = "System.Runtime.InteropServices";
private const string NativeTypedefAttribute = "NativeTypedefAttribute";
private const string InvalidHandleValueAttribute = "InvalidHandleValueAttribute";
private const string CanReturnMultipleSuccessValuesAttribute = "CanReturnMultipleSuccessValuesAttribute";
private const string CanReturnErrorsAsSuccessAttribute = "CanReturnErrorsAsSuccessAttribute";
private const string SimpleFileNameAnnotation = "SimpleFileName";
private const string NamespaceContainerAnnotation = "NamespaceContainer";
private const string OriginalDelegateAnnotation = "OriginalDelegate";
Expand Down Expand Up @@ -1378,9 +1380,9 @@ nsContents.Key is object
nativeArrayInfo?.CountParamIndex.HasValue is true ? LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(nativeArrayInfo.Value.CountParamIndex.Value)) : null);
}

internal static TypeSyntax MakeSpanOfT(TypeSyntax typeArgument) => GenericName("Span").AddTypeArgumentListArguments(typeArgument);
internal static TypeSyntax MakeSpanOfT(TypeSyntax typeArgument) => GenericName(nameof(Span<int>)).AddTypeArgumentListArguments(typeArgument);

internal static TypeSyntax MakeReadOnlySpanOfT(TypeSyntax typeArgument) => GenericName("ReadOnlySpan").AddTypeArgumentListArguments(typeArgument);
internal static TypeSyntax MakeReadOnlySpanOfT(TypeSyntax typeArgument) => GenericName(nameof(ReadOnlySpan<int>)).AddTypeArgumentListArguments(typeArgument);

/// <summary>
/// Checks whether an exception was originally thrown because of a target platform incompatibility.
Expand Down Expand Up @@ -1754,7 +1756,7 @@ internal void RequestMacro(MethodDeclarationSyntax macro)
// Collect all the known invalid values for this handle.
// If no invalid values are given (e.g. BSTR), we'll just assume 0 is invalid.
HashSet<IntPtr> invalidHandleValues = this.GetInvalidHandleValues(((HandleTypeHandleInfo)releaseMethodParameterTypeHandleInfo).Handle);
long preferredInvalidValue = invalidHandleValues.Contains(new IntPtr(-1)) ? -1 : invalidHandleValues.FirstOrDefault().ToInt64();
IntPtr preferredInvalidValue = GetPreferredInvalidHandleValue(invalidHandleValues);

CustomAttributeHandleCollection? atts = this.GetReturnTypeCustomAttributes(releaseMethodDef);
TypeSyntaxAndMarshaling releaseMethodReturnType = releaseMethodSignature.ReturnType.ToTypeSyntax(this.externSignatureTypeSettings, atts);
Expand All @@ -1765,7 +1767,7 @@ internal void RequestMacro(MethodDeclarationSyntax macro)

MemberAccessExpressionSyntax thisHandle = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("handle"));
ExpressionSyntax intptrZero = DefaultExpression(IntPtrTypeSyntax);
ExpressionSyntax invalidHandleIntPtr = ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(preferredInvalidValue))));
ExpressionSyntax invalidHandleIntPtr = IntPtrExpr(preferredInvalidValue);

// private static readonly IntPtr INVALID_HANDLE_VALUE = new IntPtr(-1);
IdentifierNameSyntax invalidValueFieldName = IdentifierName("INVALID_HANDLE_VALUE");
Expand Down Expand Up @@ -2247,6 +2249,8 @@ protected virtual void Dispose(bool disposing)

private static SyntaxToken TokenWithLineFeed(SyntaxKind syntaxKind) => SyntaxFactory.Token(TriviaList(), syntaxKind, TriviaList(LineFeed));

private static IntPtr GetPreferredInvalidHandleValue(HashSet<IntPtr> invalidHandleValues) => invalidHandleValues.Contains(new IntPtr(-1)) ? new IntPtr(-1) : invalidHandleValues.FirstOrDefault();

private static bool RequiresUnsafe(TypeSyntax? typeSyntax) => typeSyntax is PointerTypeSyntax || typeSyntax is FunctionPointerTypeSyntax;

private static string GetClassNameForModule(string moduleName) =>
Expand Down Expand Up @@ -2620,6 +2624,9 @@ private static Guid DecodeGuidFromAttribute(CustomAttribute guidAttribute)

private static bool IsHresult(TypeHandleInfo? typeHandleInfo) => typeHandleInfo is HandleTypeHandleInfo handleInfo && handleInfo.IsType("HRESULT");

private static ExpressionSyntax IntPtrExpr(IntPtr value) => ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value.ToInt64()))));

private T AddApiDocumentation<T>(string api, T memberDeclaration)
where T : MemberDeclarationSyntax
{
Expand Down Expand Up @@ -3784,8 +3791,8 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
bool preserveSig = interfaceAsSubtype
|| !IsHresult(signature.ReturnType)
|| (methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), "CanReturnMultipleSuccessValuesAttribute") is not null
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), "CanReturnErrorsAsSuccessAttribute") is not null
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnMultipleSuccessValuesAttribute) is not null
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnErrorsAsSuccessAttribute) is not null
|| this.options.ComInterop.PreserveSigMethods.Contains($"{ifaceName}.{methodName}")
|| this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString());

Expand Down Expand Up @@ -4715,6 +4722,22 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(externParam.Type).AddVariables(
VariableDeclarator(typeDefHandleName.Identifier))));

// throw new ArgumentNullException(nameof(hTemplateFile));
StatementSyntax nullHandleStatement = ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentNullException))).WithArgumentList(ArgumentList().AddArguments(Argument(NameOfExpression(IdentifierName(externParam.Identifier.ValueText))))));
if (isOptional)
{
HashSet<IntPtr> invalidValues = this.GetInvalidHandleValues(parameterHandleTypeInfo.Handle);
if (invalidValues.Count > 0)
{
// (HANDLE)new IntPtr(-1);
IntPtr invalidValue = GetPreferredInvalidHandleValue(invalidValues);
ExpressionSyntax invalidExpression = CastExpression(externParam.Type, IntPtrExpr(invalidValue));

// hTemplateFileLocal = invalid-handle-value;
nullHandleStatement = ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, invalidExpression));
}
}

// if (hTemplateFile is object)
leadingStatements.Add(IfStatement(
BinaryExpression(SyntaxKind.IsExpression, origName, PredefinedType(Token(SyntaxKind.ObjectKeyword))),
Expand All @@ -4732,7 +4755,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousGetHandle))), ArgumentList())))
.WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken)))),
//// else hTemplateFileLocal = default;
ElseClause(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, DefaultExpression(externParam.Type.WithoutTrailingTrivia())).WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken))))));
ElseClause(nullHandleStatement)));

// if (hTemplateFileAddRef)
// hTemplateFile.DangerousRelease();
Expand Down Expand Up @@ -5753,7 +5776,7 @@ InvocationExpressionSyntax SliceAtLengthToString(ExpressionSyntax readOnlySpan)
BinaryExpression(SyntaxKind.LessThanExpression, lengthParameterName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
BinaryExpression(SyntaxKind.GreaterThanExpression, lengthParameterName, lengthConstant)),
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))).AddArgumentListArguments(
Argument(InvocationExpression(IdentifierName("nameof"), ArgumentList().AddArguments(Argument(lengthParameterName)))),
Argument(NameOfExpression(lengthParameterName)),
Argument(lengthParameterName),
Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal("Length must be between 0 and the fixed array length, inclusive.")))))),
FixedBlock(
Expand Down
3 changes: 2 additions & 1 deletion test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ public void InterestingAPIs(
"PZZWSTR",
"PCZZSTR",
"PCZZWSTR",
"NCryptImportKey", // friendly overload takes SafeHandle backed by a UIntPtr instead of IntPtr
"IUIAutomation", // non-preservesig retval COM method with a array size index parameter
"IHTMLWindow2", // contains properties named with C# reserved keywords
"CreateFile", // built-in SafeHandle use
Expand Down Expand Up @@ -2653,7 +2654,7 @@ internal static unsafe Microsoft.Win32.SafeHandles.SafeFileHandle CreateFile(str
hTemplateFileLocal = (winmdroot.Foundation.HANDLE)hTemplateFile.DangerousGetHandle();
}}
else
hTemplateFileLocal = default(winmdroot.Foundation.HANDLE);
hTemplateFileLocal= (winmdroot.Foundation.HANDLE )new IntPtr(-1L);
winmdroot.Foundation.HANDLE __result = PInvoke.CreateFile(lpFileNameLocal, dwDesiredAccess, dwShareMode, lpSecurityAttributes.HasValue ? &lpSecurityAttributesLocal : null, dwCreationDisposition, dwFlagsAndAttributes, hTemplateFileLocal);
return new Microsoft.Win32.SafeHandles.SafeFileHandle(__result, ownsHandle: true);
}}
Expand Down

0 comments on commit 25df5f1

Please sign in to comment.