Skip to content

Commit

Permalink
Add helper APIs for variable-length inline arrays
Browse files Browse the repository at this point in the history
Closes #387
  • Loading branch information
AArnott committed Jan 25, 2024
1 parent 9059d51 commit d80cf0f
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla

internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space)));

internal static SizeOfExpressionSyntax SizeOfExpression(TypeSyntax type) => SyntaxFactory.SizeOfExpression(Token(SyntaxKind.SizeOfKeyword), Token(SyntaxKind.OpenParenToken), type, Token(SyntaxKind.CloseParenToken));

internal static MemberAccessExpressionSyntax MemberAccessExpression(SyntaxKind kind, ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.MemberAccessExpression(kind, expression, Token(GetMemberAccessExpressionOperatorTokenKind(kind)), name);

internal static ConditionalAccessExpressionSyntax ConditionalAccessExpression(ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.ConditionalAccessExpression(expression, Token(SyntaxKind.QuestionToken), MemberBindingExpression(name));
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Features.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public partial class Generator
private readonly bool canUseSpan;
private readonly bool canCallCreateSpan;
private readonly bool canUseUnsafeAsRef;
private readonly bool canUseUnsafeAdd;
private readonly bool canUseUnsafeNullRef;
private readonly bool canUseUnmanagedCallersOnlyAttribute;
private readonly bool canUseSetLastPInvokeError;
Expand Down
148 changes: 148 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Struct.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle

// If the last field has the [FlexibleArray] attribute, we must disable marshaling since the struct
// is only ever valid when accessed via a pointer since the struct acts as a header of an arbitrarily-sized array.
FieldDefinitionHandle flexibleArrayFieldHandle = default;
MethodDeclarationSyntax? sizeOfMethod = null;
if (typeDef.GetFields().LastOrDefault() is FieldDefinitionHandle { IsNil: false } lastFieldHandle)
{
FieldDefinition lastField = this.Reader.GetFieldDefinition(lastFieldHandle);
if (MetadataUtilities.FindAttribute(this.Reader, lastField.GetCustomAttributes(), InteropDecorationNamespace, FlexibleArrayAttribute) is not null)
{
flexibleArrayFieldHandle = lastFieldHandle;
context = context with { AllowMarshaling = false };
}
}
Expand Down Expand Up @@ -80,6 +83,37 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
.WithArgumentList(BracketedArgumentList(SingletonSeparatedList(Argument(size)))))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword), Token(SyntaxKind.FixedKeyword));
}
else if (fieldDefHandle == flexibleArrayFieldHandle)
{
CustomAttributeHandleCollection fieldAttributes = fieldDef.GetCustomAttributes();
var fieldTypeInfo = (ArrayTypeHandleInfo)fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null);
TypeSyntax fieldType = fieldTypeInfo.ElementType.ToTypeSyntax(typeSettings, GeneratingElement.StructMember, fieldAttributes).Type;

if (fieldType is PointerTypeSyntax or FunctionPointerTypeSyntax)
{
// These types are not allowed as generic type arguments (https://github.com/dotnet/runtime/issues/13627)
// so we have to generate a special nested struct dedicated to this type instead of using the generic type.
StructDeclarationSyntax helperStruct = this.DeclareVariableLengthInlineArrayHelper(context, fieldType);
additionalMembers = additionalMembers.Add(helperStruct);

field = FieldDeclaration(
VariableDeclaration(IdentifierName(helperStruct.Identifier.ValueText)))
.AddDeclarationVariables(fieldDeclarator)
.AddModifiers(TokenWithSpace(this.Visibility));
}
else
{
this.RequestVariableLengthInlineArrayHelper(context);
field = FieldDeclaration(
VariableDeclaration(
GenericName($"global::Windows.Win32.VariableLengthInlineArray")
.WithTypeArgumentList(TypeArgumentList().AddArguments(fieldType))))
.AddDeclarationVariables(fieldDeclarator)
.AddModifiers(TokenWithSpace(this.Visibility));
}

sizeOfMethod = this.DeclareSizeOfMethod(name, fieldType, typeSettings);
}
else
{
CustomAttributeHandleCollection fieldAttributes = fieldDef.GetCustomAttributes();
Expand Down Expand Up @@ -334,6 +368,12 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
}
}

// Add a SizeOf method, if there is a FlexibleArray field.
if (sizeOfMethod is not null)
{
members.Add(sizeOfMethod);
}

// Add the additional members, taking care to not introduce redundant declarations.
members.AddRange(additionalMembers.Where(c => c is not StructDeclarationSyntax cs || !members.OfType<StructDeclarationSyntax>().Any(m => m.Identifier.ValueText == cs.Identifier.ValueText)));

Expand Down Expand Up @@ -370,6 +410,95 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
return result;
}

private StructDeclarationSyntax DeclareVariableLengthInlineArrayHelper(Context context, TypeSyntax fieldType)
{
IdentifierNameSyntax firstElementFieldName = IdentifierName("e0");
List<MemberDeclarationSyntax> members = new();

// internal unsafe T e0;
members.Add(FieldDeclaration(VariableDeclaration(fieldType).AddVariables(VariableDeclarator(firstElementFieldName.Identifier)))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword)));

if (this.canUseUnsafeAdd)
{
////[MethodImpl(MethodImplOptions.AggressiveInlining)]
////get { fixed (int** p = &e0) return *(p + index); }
IdentifierNameSyntax pLocal = IdentifierName("p");
AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
.WithBody(Block().AddStatements(
FixedStatement(
VariableDeclaration(PointerType(fieldType)).AddVariables(
VariableDeclarator(pLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, firstElementFieldName)))),
ReturnStatement(PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.AddExpression, pLocal, IdentifierName("index"))))))))
.AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining)));

////[MethodImpl(MethodImplOptions.AggressiveInlining)]
////set { fixed (int** p = &e0) *(p + index) = value; }
AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration)
.WithBody(Block().AddStatements(
FixedStatement(
VariableDeclaration(PointerType(fieldType)).AddVariables(
VariableDeclarator(pLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, firstElementFieldName)))),
ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.AddExpression, pLocal, IdentifierName("index")))),
IdentifierName("value"))))))
.AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining)));

////internal unsafe T this[int index]
members.Add(IndexerDeclaration(fieldType.WithTrailingTrivia(Space))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword))
.AddParameterListParameters(Parameter(Identifier("index")).WithType(PredefinedType(TokenWithSpace(SyntaxKind.IntKeyword))))
.AddAccessorListAccessors(getter, setter));
}

// internal partial struct VariableLengthInlineArrayHelper
return StructDeclaration(Identifier("VariableLengthInlineArrayHelper"))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.PartialKeyword))
.AddMembers(members.ToArray());
}

private MethodDeclarationSyntax DeclareSizeOfMethod(TypeSyntax structType, TypeSyntax elementType, TypeSyntaxSettings typeSettings)
{
PredefinedTypeSyntax intType = PredefinedType(TokenWithSpace(SyntaxKind.IntKeyword));
IdentifierNameSyntax countName = IdentifierName("count");
IdentifierNameSyntax localName = IdentifierName("v");
List<StatementSyntax> statements = new();

// int v = sizeof(OUTER_STRUCT);
statements.Add(LocalDeclarationStatement(VariableDeclaration(intType).AddVariables(
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(SizeOfExpression(structType))))));

// if (count > 1)
// v += checked((count - 1) * sizeof(ELEMENT_TYPE));
// else if (count < 0)
// throw new ArgumentOutOfRangeException(nameof(count));
statements.Add(IfStatement(
BinaryExpression(SyntaxKind.GreaterThanExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1))),
ExpressionStatement(AssignmentExpression(
SyntaxKind.AddAssignmentExpression,
localName,
CheckedExpression(BinaryExpression(
SyntaxKind.MultiplyExpression,
ParenthesizedExpression(BinaryExpression(SyntaxKind.SubtractExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1)))),
SizeOfExpression(elementType))))),
ElseClause(IfStatement(
BinaryExpression(SyntaxKind.LessThanExpression, countName, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))))).WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken)))).WithCloseParenToken(TokenWithLineFeed(SyntaxKind.CloseParenToken)));

// return v;
statements.Add(ReturnStatement(localName));

// internal static unsafe int SizeOf(int count)
MethodDeclarationSyntax sizeOfMethod = MethodDeclaration(intType, Identifier("SizeOf"))
.AddParameterListParameters(Parameter(countName.Identifier).WithType(intType))
.WithBody(Block().AddStatements(statements.ToArray()))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.UnsafeKeyword))
.WithLeadingTrivia(ParseLeadingTrivia("/// <summary>Computes the amount of memory that must be allocated to store this struct, including the specified number of elements in the variable length inline array at the end.</summary>\n"));

return sizeOfMethod;
}

private (TypeSyntax FieldType, SyntaxList<MemberDeclarationSyntax> AdditionalMembers, AttributeSyntax? MarshalAsAttribute) ReinterpretFieldType(FieldDefinition fieldDef, TypeSyntax originalType, CustomAttributeHandleCollection customAttributes, Context context)
{
TypeSyntaxSettings typeSettings = context.Filter(this.fieldTypeSettings);
Expand Down Expand Up @@ -397,4 +526,23 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle

return (originalType, default(SyntaxList<MemberDeclarationSyntax>), marshalAs);
}

private void RequestVariableLengthInlineArrayHelper(Context context)
{
if (this.IsWin32Sdk)
{
if (!this.IsTypeAlreadyFullyDeclared($"{this.Namespace}.{this.variableLengthInlineArrayStruct.Identifier.ValueText}"))
{
this.DeclareUnscopedRefAttributeIfNecessary();
this.volatileCode.GenerateSpecialType("VariableLengthInlineArray", () => this.volatileCode.AddSpecialType("VariableLengthInlineArray", this.variableLengthInlineArrayStruct));
}
}
else if (this.SuperGenerator is not null && this.SuperGenerator.TryGetGenerator("Windows.Win32", out Generator? generator))
{
generator.volatileCode.GenerationTransaction(delegate
{
generator.RequestVariableLengthInlineArrayHelper(context);
});
}
}
}
6 changes: 5 additions & 1 deletion src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public partial class Generator : IGenerator, IDisposable
private readonly TypeSyntaxSettings errorMessageTypeSettings;

private readonly ClassDeclarationSyntax comHelperClass;
private readonly StructDeclarationSyntax variableLengthInlineArrayStruct;

private readonly Dictionary<string, IReadOnlyList<ISymbol>> findTypeSymbolIfAlreadyAvailableCache = new(StringComparer.Ordinal);
private readonly Rental<MetadataReader> metadataReader;
Expand Down Expand Up @@ -86,7 +87,8 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option

this.canUseSpan = this.compilation?.GetTypeByMetadataName(typeof(Span<>).FullName) is not null;
this.canCallCreateSpan = this.compilation?.GetTypeByMetadataName(typeof(MemoryMarshal).FullName)?.GetMembers("CreateSpan").Any() is true;
this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true;
this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("Add").Any() is true;
this.canUseUnsafeAdd = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true;
this.canUseUnsafeNullRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("NullRef").Any() is true;
this.canUseUnmanagedCallersOnlyAttribute = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute") is not null;
this.canUseSetLastPInvokeError = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.Marshal")?.GetMembers("GetLastSystemError").IsEmpty is false;
Expand All @@ -110,6 +112,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option
AddSymbolIf(this.canUseSpan, "canUseSpan");
AddSymbolIf(this.canCallCreateSpan, "canCallCreateSpan");
AddSymbolIf(this.canUseUnsafeAsRef, "canUseUnsafeAsRef");
AddSymbolIf(this.canUseUnsafeAdd, "canUseUnsafeAdd");
AddSymbolIf(this.canUseUnsafeNullRef, "canUseUnsafeNullRef");
AddSymbolIf(compilation?.GetTypeByMetadataName("System.Drawing.Point") is not null, "canUseSystemDrawing");
AddSymbolIf(this.IsFeatureAvailable(Feature.InterfaceStaticMembers), "canUseInterfaceStaticMembers");
Expand Down Expand Up @@ -149,6 +152,7 @@ void AddSymbolIf(bool condition, string symbol)
this.methodsAndConstantsClassName = IdentifierName(options.ClassName);

FetchTemplate("ComHelpers", this, out this.comHelperClass);
FetchTemplate("VariableLengthInlineArray`1", this, out this.variableLengthInlineArrayStruct);
}

internal enum GeneratingElement
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
internal struct VariableLengthInlineArray<T>
where T : unmanaged
{
internal T e0;

#if canUseUnsafeAdd
internal ref T this[int index]
{
[UnscopedRef]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => ref Unsafe.Add(ref this.e0, index);
}
#endif

#if canUseSpan
[UnscopedRef]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal Span<T> AsSpan(int length)
{
#if canCallCreateSpan
return MemoryMarshal.CreateSpan(ref this.e0, length);
#else
unsafe
{
fixed (void* p = &this.e0)
{
return new Span<T>(p, length);
}
}
#endif
}
#endif
}
37 changes: 37 additions & 0 deletions test/GenerationSandbox.Tests/FlexibleArrayTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Runtime.InteropServices;
using Windows.Win32.System.Ole;

public class FlexibleArrayTests
{
[Fact]
public unsafe void FlexibleArraySizing()
{
const int count = 3;
PAGESET* pPageSet = (PAGESET*)Marshal.AllocHGlobal(PAGESET.SizeOf(count));
try
{
pPageSet->rgPages[0].nFromPage = 0;

Span<PAGERANGE> pageRange = pPageSet->rgPages.AsSpan(count);
for (int i = 0; i < count; i++)
{
pageRange[i].nFromPage = i * 2;
pageRange[i].nToPage = (i * 2) + 1;
}
}
finally
{
Marshal.FreeHGlobal((IntPtr)pPageSet);
}
}

[Fact]
public void SizeOf_Minimum1Element()
{
Assert.Equal(PAGESET.SizeOf(1), PAGESET.SizeOf(0));
Assert.Equal(Marshal.SizeOf<PAGERANGE>(), PAGESET.SizeOf(2) - PAGESET.SizeOf(1));
}
}
1 change: 1 addition & 0 deletions test/GenerationSandbox.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ MAKELRESULT
MAKEWPARAM
MAX_PATH
NTSTATUS
PAGESET
PathParseIconLocation
PROCESS_BASIC_INFORMATION
PZZSTR
Expand Down
16 changes: 16 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/StructTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,22 @@ public void StructConstantsAreGeneratedAsConstants()
Assert.NotEmpty(type.Members.OfType<FieldDeclarationSyntax>().Where(f => f.Modifiers.Any(SyntaxKind.ConstKeyword)));
}

[Theory]
[MemberData(nameof(TFMData))]
public void FlexibleArrayMember(string tfm)
{
this.compilation = this.starterCompilations[tfm];
this.GenerateApi("BITMAPINFO");
var type = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType("BITMAPINFO"));
FieldDeclarationSyntax flexArrayField = Assert.Single(type.Members.OfType<FieldDeclarationSyntax>(), m => m.Declaration.Variables.Any(v => v.Identifier.ValueText == "bmiColors"));
var fieldType = Assert.IsType<GenericNameSyntax>(Assert.IsType<QualifiedNameSyntax>(flexArrayField.Declaration.Type).Right);
Assert.Equal("VariableLengthInlineArray", fieldType.Identifier.ValueText);
Assert.Equal("RGBQUAD", Assert.IsType<QualifiedNameSyntax>(Assert.Single(fieldType.TypeArgumentList.Arguments)).Right.Identifier.ValueText);

// Verify that the SizeOf method was generated.
Assert.Single(this.FindGeneratedMethod("SizeOf"));
}

[Theory]
[CombinatorialData]
public void InterestingStructs(
Expand Down

0 comments on commit d80cf0f

Please sign in to comment.