Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use documented invalid handles when SafeHandle is null #760

Merged
merged 3 commits into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static ThrowStatementSyntax ThrowStatement(ExpressionSyntax expression) => SyntaxFactory.ThrowStatement(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression, Semicolon);

internal static ThrowExpressionSyntax ThrowExpression(ExpressionSyntax expression) => SyntaxFactory.ThrowExpression(Token(TriviaList(), SyntaxKind.ThrowKeyword, TriviaList(Space)), expression);

internal static ExpressionSyntax NameOfExpression(IdentifierNameSyntax identifierName) => SyntaxFactory.InvocationExpression(IdentifierName("nameof"), ArgumentList(SingletonSeparatedList(Argument(identifierName))));

internal static ReturnStatementSyntax ReturnStatement(ExpressionSyntax? expression) => SyntaxFactory.ReturnStatement(Token(TriviaList(), SyntaxKind.ReturnKeyword, TriviaList(Space)), expression!, Semicolon);

internal static DelegateDeclarationSyntax DelegateDeclaration(TypeSyntax returnType, SyntaxToken identifier) => SyntaxFactory.DelegateDeclaration(default(SyntaxList<AttributeListSyntax>), default(SyntaxTokenList), Token(TriviaList(), SyntaxKind.DelegateKeyword, TriviaList(Space)), returnType.WithTrailingTrivia(TriviaList(Space)), identifier, null, ParameterList(), default, Semicolon);
Expand Down
39 changes: 31 additions & 8 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public class Generator : IDisposable
private const string SystemRuntimeInteropServices = "System.Runtime.InteropServices";
private const string NativeTypedefAttribute = "NativeTypedefAttribute";
private const string InvalidHandleValueAttribute = "InvalidHandleValueAttribute";
private const string CanReturnMultipleSuccessValuesAttribute = "CanReturnMultipleSuccessValuesAttribute";
private const string CanReturnErrorsAsSuccessAttribute = "CanReturnErrorsAsSuccessAttribute";
private const string SimpleFileNameAnnotation = "SimpleFileName";
private const string NamespaceContainerAnnotation = "NamespaceContainer";
private const string OriginalDelegateAnnotation = "OriginalDelegate";
Expand Down Expand Up @@ -1378,9 +1380,9 @@ nsContents.Key is object
nativeArrayInfo?.CountParamIndex.HasValue is true ? LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(nativeArrayInfo.Value.CountParamIndex.Value)) : null);
}

internal static TypeSyntax MakeSpanOfT(TypeSyntax typeArgument) => GenericName("Span").AddTypeArgumentListArguments(typeArgument);
internal static TypeSyntax MakeSpanOfT(TypeSyntax typeArgument) => GenericName(nameof(Span<int>)).AddTypeArgumentListArguments(typeArgument);

internal static TypeSyntax MakeReadOnlySpanOfT(TypeSyntax typeArgument) => GenericName("ReadOnlySpan").AddTypeArgumentListArguments(typeArgument);
internal static TypeSyntax MakeReadOnlySpanOfT(TypeSyntax typeArgument) => GenericName(nameof(ReadOnlySpan<int>)).AddTypeArgumentListArguments(typeArgument);

/// <summary>
/// Checks whether an exception was originally thrown because of a target platform incompatibility.
Expand Down Expand Up @@ -1754,7 +1756,7 @@ internal void RequestMacro(MethodDeclarationSyntax macro)
// Collect all the known invalid values for this handle.
// If no invalid values are given (e.g. BSTR), we'll just assume 0 is invalid.
HashSet<IntPtr> invalidHandleValues = this.GetInvalidHandleValues(((HandleTypeHandleInfo)releaseMethodParameterTypeHandleInfo).Handle);
long preferredInvalidValue = invalidHandleValues.Contains(new IntPtr(-1)) ? -1 : invalidHandleValues.FirstOrDefault().ToInt64();
IntPtr preferredInvalidValue = GetPreferredInvalidHandleValue(invalidHandleValues);

CustomAttributeHandleCollection? atts = this.GetReturnTypeCustomAttributes(releaseMethodDef);
TypeSyntaxAndMarshaling releaseMethodReturnType = releaseMethodSignature.ReturnType.ToTypeSyntax(this.externSignatureTypeSettings, atts);
Expand All @@ -1765,7 +1767,7 @@ internal void RequestMacro(MethodDeclarationSyntax macro)

MemberAccessExpressionSyntax thisHandle = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("handle"));
ExpressionSyntax intptrZero = DefaultExpression(IntPtrTypeSyntax);
ExpressionSyntax invalidHandleIntPtr = ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(preferredInvalidValue))));
ExpressionSyntax invalidHandleIntPtr = IntPtrExpr(preferredInvalidValue);

// private static readonly IntPtr INVALID_HANDLE_VALUE = new IntPtr(-1);
IdentifierNameSyntax invalidValueFieldName = IdentifierName("INVALID_HANDLE_VALUE");
Expand Down Expand Up @@ -2247,6 +2249,8 @@ protected virtual void Dispose(bool disposing)

private static SyntaxToken TokenWithLineFeed(SyntaxKind syntaxKind) => SyntaxFactory.Token(TriviaList(), syntaxKind, TriviaList(LineFeed));

private static IntPtr GetPreferredInvalidHandleValue(HashSet<IntPtr> invalidHandleValues) => invalidHandleValues.Contains(new IntPtr(-1)) ? new IntPtr(-1) : invalidHandleValues.FirstOrDefault();

private static bool RequiresUnsafe(TypeSyntax? typeSyntax) => typeSyntax is PointerTypeSyntax || typeSyntax is FunctionPointerTypeSyntax;

private static string GetClassNameForModule(string moduleName) =>
Expand Down Expand Up @@ -2620,6 +2624,9 @@ private static Guid DecodeGuidFromAttribute(CustomAttribute guidAttribute)

private static bool IsHresult(TypeHandleInfo? typeHandleInfo) => typeHandleInfo is HandleTypeHandleInfo handleInfo && handleInfo.IsType("HRESULT");

private static ExpressionSyntax IntPtrExpr(IntPtr value) => ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value.ToInt64()))));

private T AddApiDocumentation<T>(string api, T memberDeclaration)
where T : MemberDeclarationSyntax
{
Expand Down Expand Up @@ -3784,8 +3791,8 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
bool preserveSig = interfaceAsSubtype
|| !IsHresult(signature.ReturnType)
|| (methodDefinition.ImplAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), "CanReturnMultipleSuccessValuesAttribute") is not null
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), "CanReturnErrorsAsSuccessAttribute") is not null
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnMultipleSuccessValuesAttribute) is not null
|| this.FindInteropDecorativeAttribute(methodDefinition.GetCustomAttributes(), CanReturnErrorsAsSuccessAttribute) is not null
|| this.options.ComInterop.PreserveSigMethods.Contains($"{ifaceName}.{methodName}")
|| this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString());

Expand Down Expand Up @@ -4715,6 +4722,22 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(externParam.Type).AddVariables(
VariableDeclarator(typeDefHandleName.Identifier))));

// throw new ArgumentNullException(nameof(hTemplateFile));
StatementSyntax nullHandleStatement = ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentNullException))).WithArgumentList(ArgumentList().AddArguments(Argument(NameOfExpression(IdentifierName(externParam.Identifier.ValueText))))));
if (isOptional)
{
HashSet<IntPtr> invalidValues = this.GetInvalidHandleValues(parameterHandleTypeInfo.Handle);
if (invalidValues.Count > 0)
{
// (HANDLE)new IntPtr(-1);
IntPtr invalidValue = GetPreferredInvalidHandleValue(invalidValues);
ExpressionSyntax invalidExpression = CastExpression(externParam.Type, IntPtrExpr(invalidValue));

// hTemplateFileLocal = invalid-handle-value;
nullHandleStatement = ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, invalidExpression));
}
}

// if (hTemplateFile is object)
leadingStatements.Add(IfStatement(
BinaryExpression(SyntaxKind.IsExpression, origName, PredefinedType(Token(SyntaxKind.ObjectKeyword))),
Expand All @@ -4732,7 +4755,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, origName, IdentifierName(nameof(SafeHandle.DangerousGetHandle))), ArgumentList())))
.WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken)))),
//// else hTemplateFileLocal = default;
ElseClause(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, typeDefHandleName, DefaultExpression(externParam.Type.WithoutTrailingTrivia())).WithOperatorToken(TokenWithSpaces(SyntaxKind.EqualsToken))))));
ElseClause(nullHandleStatement)));

// if (hTemplateFileAddRef)
// hTemplateFile.DangerousRelease();
Expand Down Expand Up @@ -5753,7 +5776,7 @@ InvocationExpressionSyntax SliceAtLengthToString(ExpressionSyntax readOnlySpan)
BinaryExpression(SyntaxKind.LessThanExpression, lengthParameterName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
BinaryExpression(SyntaxKind.GreaterThanExpression, lengthParameterName, lengthConstant)),
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))).AddArgumentListArguments(
Argument(InvocationExpression(IdentifierName("nameof"), ArgumentList().AddArguments(Argument(lengthParameterName)))),
Argument(NameOfExpression(lengthParameterName)),
Argument(lengthParameterName),
Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal("Length must be between 0 and the fixed array length, inclusive.")))))),
FixedBlock(
Expand Down
3 changes: 2 additions & 1 deletion test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ public void InterestingAPIs(
"PZZWSTR",
"PCZZSTR",
"PCZZWSTR",
"NCryptImportKey", // friendly overload takes SafeHandle backed by a UIntPtr instead of IntPtr
"IUIAutomation", // non-preservesig retval COM method with a array size index parameter
"IHTMLWindow2", // contains properties named with C# reserved keywords
"CreateFile", // built-in SafeHandle use
Expand Down Expand Up @@ -2653,7 +2654,7 @@ internal static unsafe Microsoft.Win32.SafeHandles.SafeFileHandle CreateFile(str
hTemplateFileLocal = (winmdroot.Foundation.HANDLE)hTemplateFile.DangerousGetHandle();
}}
else
hTemplateFileLocal = default(winmdroot.Foundation.HANDLE);
hTemplateFileLocal= (winmdroot.Foundation.HANDLE )new IntPtr(-1L);
winmdroot.Foundation.HANDLE __result = PInvoke.CreateFile(lpFileNameLocal, dwDesiredAccess, dwShareMode, lpSecurityAttributes.HasValue ? &lpSecurityAttributesLocal : null, dwCreationDisposition, dwFlagsAndAttributes, hTemplateFileLocal);
return new Microsoft.Win32.SafeHandles.SafeFileHandle(__result, ownsHandle: true);
}}
Expand Down