Skip to content

Commit

Permalink
Fix NRE thrown from some COM extension methods
Browse files Browse the repository at this point in the history
Fixes #1041
  • Loading branch information
AArnott committed Sep 20, 2023
1 parent 4e9fbe2 commit 24fb7be
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
4 changes: 4 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static MemberAccessExpressionSyntax MemberAccessExpression(SyntaxKind kind, ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.MemberAccessExpression(kind, expression, Token(GetMemberAccessExpressionOperatorTokenKind(kind)), name);

internal static ConditionalAccessExpressionSyntax ConditionalAccessExpression(ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.ConditionalAccessExpression(expression, Token(SyntaxKind.QuestionToken), MemberBindingExpression(name));

internal static MemberBindingExpressionSyntax MemberBindingExpression(SimpleNameSyntax name) => SyntaxFactory.MemberBindingExpression(Token(SyntaxKind.DotToken), name);

internal static NameColonSyntax NameColon(IdentifierNameSyntax name) => SyntaxFactory.NameColon(name, Token(TriviaList(), SyntaxKind.ColonToken, TriviaList(Space)));

internal static UsingDirectiveSyntax UsingDirective(NameSyntax name) => SyntaxFactory.UsingDirective(Token(TriviaList(), SyntaxKind.UsingKeyword, TriviaList(Space)), default, null, name, Semicolon);
Expand Down
29 changes: 15 additions & 14 deletions src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi

#pragma warning disable SA1114 // Parameter list should follow declaration
static ParameterSyntax StripAttributes(ParameterSyntax parameter) => parameter.WithAttributeLists(List<AttributeListSyntax>());
static ExpressionSyntax GetSpanLength(ExpressionSyntax span) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, span, IdentifierName(nameof(Span<int>.Length)));
static ExpressionSyntax GetSpanLength(ExpressionSyntax span, bool isRefType) => isRefType ? ParenthesizedExpression(BinaryExpression(SyntaxKind.CoalesceExpression, ConditionalAccessExpression(span, IdentifierName(nameof(Span<int>.Length))), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))) : MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, span, IdentifierName(nameof(Span<int>.Length)));
bool isReleaseMethod = this.MetadataIndex.ReleaseMethods.Contains(externMethodDeclaration.Identifier.ValueText);
bool doNotRelease = this.FindInteropDecorativeAttribute(this.GetReturnTypeCustomAttributes(methodDefinition), DoNotReleaseAttribute) is not null;

Expand Down Expand Up @@ -270,6 +270,16 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
&& !isPointerToPointer)
{
signatureChanged = true;
bool remainsRefType = true;
if (externParam.Type is PointerTypeSyntax)
{
remainsRefType = false;
parameters[param.SequenceNumber - 1] = parameters[param.SequenceNumber - 1]
.WithType((isIn && isConst ? MakeReadOnlySpanOfT(elementType) : MakeSpanOfT(elementType)).WithTrailingTrivia(TriviaList(Space)));
fixedBlocks.Add(VariableDeclaration(externParam.Type).AddVariables(
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(origName))));
arguments[param.SequenceNumber - 1] = Argument(localName);
}

if (lengthParamUsedBy.TryGetValue(sizeParamIndex.Value, out int userIndex))
{
Expand All @@ -280,25 +290,16 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
leadingStatements.Add(IfStatement(
BinaryExpression(
SyntaxKind.NotEqualsExpression,
GetSpanLength(otherUserName),
GetSpanLength(origName)),
GetSpanLength(otherUserName, parameters[userIndex].Type is ArrayTypeSyntax),
GetSpanLength(origName, remainsRefType)),
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentException))).WithArgumentList(ArgumentList()))));
}
else
{
lengthParamUsedBy.Add(sizeParamIndex.Value, param.SequenceNumber - 1);
}

if (externParam.Type is PointerTypeSyntax)
{
parameters[param.SequenceNumber - 1] = parameters[param.SequenceNumber - 1]
.WithType((isIn && isConst ? MakeReadOnlySpanOfT(elementType) : MakeSpanOfT(elementType)).WithTrailingTrivia(TriviaList(Space)));
fixedBlocks.Add(VariableDeclaration(externParam.Type).AddVariables(
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(origName))));
arguments[param.SequenceNumber - 1] = Argument(localName);
}

ExpressionSyntax sizeArgExpression = GetSpanLength(origName);
ExpressionSyntax sizeArgExpression = GetSpanLength(origName, remainsRefType);
if (!(parameters[sizeParamIndex.Value].Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.IntKeyword } }))
{
sizeArgExpression = CastExpression(parameters[sizeParamIndex.Value].Type!, sizeArgExpression);
Expand All @@ -322,7 +323,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
leadingStatements.Add(IfStatement(
BinaryExpression(
SyntaxKind.LessThanExpression,
GetSpanLength(origName),
GetSpanLength(origName, false /* we've converted it to be a span */),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(sizeConst.Value))),
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentException))).WithArgumentList(ArgumentList()))));
}
Expand Down

0 comments on commit 24fb7be

Please sign in to comment.