diff --git a/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj b/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj index 536eec52ae8f..e9063db196be 100644 --- a/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj +++ b/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj @@ -4,6 +4,7 @@ net5.0 8.0 System.Runtime.InteropServices + enable diff --git a/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs b/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs index 0f0d322da4dd..42c15094915a 100644 --- a/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs +++ b/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs @@ -1,5 +1,4 @@ -#nullable enable - + namespace System.Runtime.InteropServices { // [TODO] Remove once the attribute has been added to the BCL diff --git a/DllImportGenerator/Ancillary.Interop/MarshalEx.cs b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs new file mode 100644 index 000000000000..95be96d86858 --- /dev/null +++ b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs @@ -0,0 +1,29 @@ + +using System.Reflection; + +namespace System.Runtime.InteropServices +{ + /// + /// Marshalling helper methods that will likely live in S.R.IS.Marshal + /// when we integrate our APIs with dotnet/runtime. + /// + public static class MarshalEx + { + public static TSafeHandle CreateSafeHandle() + where TSafeHandle : SafeHandle + { + if (typeof(TSafeHandle).IsAbstract || typeof(TSafeHandle).GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.CreateInstance | BindingFlags.Instance, null, Type.EmptyTypes, null) == null) + { + throw new MissingMemberException($"The safe handle type '{typeof(TSafeHandle).FullName}' must be a non-abstract type with a parameterless constructor."); + } + + TSafeHandle safeHandle = (TSafeHandle)Activator.CreateInstance(typeof(TSafeHandle), nonPublic: true)!; + return safeHandle; + } + + public static void SetHandle(SafeHandle safeHandle, IntPtr handle) + { + typeof(SafeHandle).GetMethod("SetHandle", BindingFlags.NonPublic | BindingFlags.Instance)!.Invoke(safeHandle, new object[] { handle }); + } + } +} diff --git a/DllImportGenerator/Demo/Demo.csproj b/DllImportGenerator/Demo/Demo.csproj index 4cee92efcd88..fdba3fbbb354 100644 --- a/DllImportGenerator/Demo/Demo.csproj +++ b/DllImportGenerator/Demo/Demo.csproj @@ -16,4 +16,8 @@ + + + + diff --git a/DllImportGenerator/Demo/Program.cs b/DllImportGenerator/Demo/Program.cs index ff7d931ccd3c..06715a76c7c0 100644 --- a/DllImportGenerator/Demo/Program.cs +++ b/DllImportGenerator/Demo/Program.cs @@ -31,6 +31,13 @@ static void Main(string[] args) c = b; NativeExportsNE.Sum(a, ref c); Console.WriteLine($"{a} + {b} = {c}"); + + SafeHandleTests tests = new SafeHandleTests(); + + tests.ReturnValue_CreatesSafeHandle(); + tests.ByValue_CorrectlyUnwrapsHandle(); + tests.ByRefSameValue_UsesSameHandleInstance(); + tests.ByRefDifferentValue_UsesNewHandleInstance(); } } } diff --git a/DllImportGenerator/Demo/SafeHandleTests.cs b/DllImportGenerator/Demo/SafeHandleTests.cs new file mode 100644 index 000000000000..bb21cfb565ca --- /dev/null +++ b/DllImportGenerator/Demo/SafeHandleTests.cs @@ -0,0 +1,74 @@ + +using System.Runtime.InteropServices; +using Microsoft.Win32.SafeHandles; +using Xunit; + +namespace Demo +{ + partial class NativeExportsNE + { + public class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private NativeExportsSafeHandle() : base(true) + { + } + + protected override bool ReleaseHandle() + { + Assert.True(NativeExportsNE.ReleaseHandle(handle)); + return true; + } + } + + [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "alloc_handle")] + public static partial NativeExportsSafeHandle AllocateHandle(); + + [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "release_handle")] + [return:MarshalAs(UnmanagedType.I1)] + private static partial bool ReleaseHandle(nint handle); + + [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "is_handle_alive")] + [return:MarshalAs(UnmanagedType.I1)] + public static partial bool IsHandleAlive(NativeExportsSafeHandle handle); + + [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "modify_handle")] + public static partial void ModifyHandle(ref NativeExportsSafeHandle handle, [MarshalAs(UnmanagedType.I1)] bool newHandle); + } + + public class SafeHandleTests + { + [Fact] + public void ReturnValue_CreatesSafeHandle() + { + using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.AllocateHandle(); + Assert.False(handle.IsClosed); + Assert.False(handle.IsInvalid); + } + + [Fact] + public void ByValue_CorrectlyUnwrapsHandle() + { + using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.AllocateHandle(); + Assert.True(NativeExportsNE.IsHandleAlive(handle)); + } + + [Fact] + public void ByRefSameValue_UsesSameHandleInstance() + { + using NativeExportsNE.NativeExportsSafeHandle handleToDispose = NativeExportsNE.AllocateHandle(); + NativeExportsNE.NativeExportsSafeHandle handle = handleToDispose; + NativeExportsNE.ModifyHandle(ref handle, false); + Assert.Same(handleToDispose, handle); + } + + [Fact] + public void ByRefDifferentValue_UsesNewHandleInstance() + { + using NativeExportsNE.NativeExportsSafeHandle handleToDispose = NativeExportsNE.AllocateHandle(); + NativeExportsNE.NativeExportsSafeHandle handle = handleToDispose; + NativeExportsNE.ModifyHandle(ref handle, true); + Assert.NotSame(handleToDispose, handle); + handle.Dispose(); + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/Directory.Build.props b/DllImportGenerator/Directory.Build.props index 81194ffe87da..61ba7c1afc32 100644 --- a/DllImportGenerator/Directory.Build.props +++ b/DllImportGenerator/Directory.Build.props @@ -2,6 +2,7 @@ 3.8.0-3.final + 2.4.1 diff --git a/DllImportGenerator/DllImportGenerator.Test/Compiles.cs b/DllImportGenerator/DllImportGenerator.Test/Compiles.cs index 33b741e8ad25..3c66f36bbe6d 100644 --- a/DllImportGenerator/DllImportGenerator.Test/Compiles.cs +++ b/DllImportGenerator/DllImportGenerator.Test/Compiles.cs @@ -90,6 +90,7 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.DelegateMarshalAsParametersAndModifiers }; yield return new[] { CodeSnippets.BlittableStructParametersAndModifiers }; yield return new[] { CodeSnippets.GenericBlittableStructParametersAndModifiers }; + yield return new[] { CodeSnippets.BasicParametersAndModifiers("Microsoft.Win32.SafeHandles.SafeFileHandle") }; } [Theory] diff --git a/DllImportGenerator/DllImportGenerator.Test/DllImportGenerator.Test.csproj b/DllImportGenerator/DllImportGenerator.Test/DllImportGenerator.Test.csproj index 75fb2cb381c8..cd384172bce4 100644 --- a/DllImportGenerator/DllImportGenerator.Test/DllImportGenerator.Test.csproj +++ b/DllImportGenerator/DllImportGenerator.Test/DllImportGenerator.Test.csproj @@ -15,7 +15,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + runtime; build; native; contentfiles; analyzers; buildtransitive all diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index 2e50dcc2df0d..60ef933582a8 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -64,6 +64,7 @@ internal class MarshallingGenerators public static readonly Forwarder Forwarder = new Forwarder(); public static readonly BlittableMarshaller Blittable = new BlittableMarshaller(); public static readonly DelegateMarshaller Delegate = new DelegateMarshaller(); + public static readonly SafeHandleMarshaller SafeHandle = new SafeHandleMarshaller(); public static bool TryCreate(TypePositionInfo info, StubCodeContext context, out IMarshallingGenerator generator) { @@ -126,6 +127,10 @@ public static bool TryCreate(TypePositionInfo info, StubCodeContext context, out generator = Forwarder; return false; + case { MarshallingAttributeInfo: SafeHandleMarshallingInfo _}: + generator = SafeHandle; + return true; + default: generator = Forwarder; return false; diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs new file mode 100644 index 000000000000..041472a7446a --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -0,0 +1,230 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + class SafeHandleMarshaller : IMarshallingGenerator + { + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return ParseTypeName("global::System.IntPtr"); + } + + public ParameterSyntax AsParameter(TypePositionInfo info) + { + var type = info.IsByRef + ? PointerType(AsNativeType(info)) + : AsNativeType(info); + return Parameter(Identifier(info.InstanceIdentifier)) + .WithType(type); + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + string identifier = context.GetIdentifiers(info).native; + if (info.IsByRef) + { + return Argument( + PrefixUnaryExpression( + SyntaxKind.AddressOfExpression, + IdentifierName(identifier))); + } + + return Argument(IdentifierName(identifier)); + } + + public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) + { + // The high level logic (note that the parameter may be in, out or both): + // 1) If this is an input parameter we need to AddRef the SafeHandle. + // 2) If this is an output parameter we need to preallocate a SafeHandle to wrap the new native handle value. We + // must allocate this before the native call to avoid a failure point when we already have a native resource + // allocated. We must allocate a new SafeHandle even if we have one on input since both input and output native + // handles need to be tracked and released by a SafeHandle. + // 3) Initialize a local IntPtr that will be passed to the native call. If we have an input SafeHandle the value + // comes from there otherwise we get it from the new SafeHandle (which is guaranteed to be initialized to an + // invalid handle value). + // 4) If this is a out parameter we also store the original handle value (that we just computed above) in a local + // variable. + // 5) If we successfully AddRef'd the incoming SafeHandle, we need to Release it before we return. + // 6) After the native call, if this is an output parameter and the handle value we passed to native differs from + // the local copy we made then the new handle value is written into the output SafeHandle and that SafeHandle + // is propagated back to the caller. + + (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); + string addRefdIdentifier = $"{managedIdentifier}__addRefd"; + string newHandleObjectIdentifier = info.IsManagedReturnPosition + ? managedIdentifier + : $"{managedIdentifier}__newHandle"; + string handleValueBackupIdentifier = $"{nativeIdentifier}__original"; + switch (context.CurrentStage) + { + case StubCodeContext.Stage.Setup: + yield return LocalDeclarationStatement( + VariableDeclaration( + AsNativeType(info), + SingletonSeparatedList( + VariableDeclarator(nativeIdentifier)))); + if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) + { + yield return LocalDeclarationStatement( + VariableDeclaration( + PredefinedType(Token(SyntaxKind.BoolKeyword)), + SingletonSeparatedList( + VariableDeclarator(addRefdIdentifier) + .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); + + } + if (info.IsByRef && info.RefKind != RefKind.In) + { + // We create the new handle in the Setup phase + // so we eliminate the possible failure points during unmarshalling, where we would + // leak the handle if we failed to create the handle. + yield return LocalDeclarationStatement( + VariableDeclaration( + info.ManagedType.AsTypeSyntax(), + SingletonSeparatedList( + VariableDeclarator(newHandleObjectIdentifier) + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + GenericName(Identifier("CreateSafeHandle"), + TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))), + ArgumentList())))))); + yield return LocalDeclarationStatement( + VariableDeclaration( + AsNativeType(info), + SingletonSeparatedList( + VariableDeclarator(handleValueBackupIdentifier) + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(newHandleObjectIdentifier), + IdentifierName(nameof(SafeHandle.DangerousGetHandle))), + ArgumentList())))))); + } + else if (info.IsManagedReturnPosition) + { + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + GenericName(Identifier("CreateSafeHandle"), + TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))), + ArgumentList()))); + } + break; + case StubCodeContext.Stage.Marshal: + if (info.RefKind != RefKind.Out) + { + // .DangerousAddRef(ref ); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName(nameof(SafeHandle.DangerousAddRef))), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(addRefdIdentifier)) + .WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))))); + + + ExpressionSyntax assignHandleToNativeExpression = + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(nativeIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName(nameof(SafeHandle.DangerousGetHandle))), + ArgumentList())); + if (info.IsByRef && info.RefKind != RefKind.In) + { + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(handleValueBackupIdentifier), + assignHandleToNativeExpression)); + } + else + { + yield return ExpressionStatement(assignHandleToNativeExpression); + } + } + break; + case StubCodeContext.Stage.GuaranteedUnmarshal: + StatementSyntax unmarshalStatement = ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MarshalEx), + IdentifierName("SetHandle")), + ArgumentList(SeparatedList( + new [] + { + Argument(IdentifierName(newHandleObjectIdentifier)), + Argument(IdentifierName(nativeIdentifier)) + })))); + + if(info.IsManagedReturnPosition) + { + yield return unmarshalStatement; + } + else if (info.RefKind == RefKind.Out) + { + yield return unmarshalStatement; + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + IdentifierName(newHandleObjectIdentifier))); + } + else if (info.RefKind == RefKind.Ref) + { + // Decrement refcount on original SafeHandle if we addrefd + yield return IfStatement( + IdentifierName(addRefdIdentifier), + ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName(nameof(SafeHandle.DangerousRelease))), + ArgumentList()))); + + // Do not unmarshal the handle if the value didn't change. + yield return IfStatement( + BinaryExpression(SyntaxKind.NotEqualsExpression, + IdentifierName(handleValueBackupIdentifier), + IdentifierName(nativeIdentifier)), + Block( + unmarshalStatement, + ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + IdentifierName(newHandleObjectIdentifier))))); + } + break; + case StubCodeContext.Stage.Cleanup: + if (!info.IsByRef || info.RefKind == RefKind.In) + { + yield return IfStatement( + IdentifierName(addRefdIdentifier), + ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName(nameof(SafeHandle.DangerousRelease))), + ArgumentList()))); + } + break; + default: + break; + } + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true; + } +} diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index a4736aa155a8..0462c8cdd191 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -12,7 +12,7 @@ namespace Microsoft.Interop // for C# 10 discriminated unions. Once discriminated unions are released, // these should be updated to be implemented as a discriminated union. - internal abstract record MarshallingAttributeInfo {} + internal abstract record MarshallingInfo {} /// /// User-applied System.Runtime.InteropServices.MarshalAsAttribute @@ -23,14 +23,14 @@ internal sealed record MarshalAsInfo( string? CustomMarshallerCookie, UnmanagedType UnmanagedArraySubType, int ArraySizeConst, - short ArraySizeParamIndex) : MarshallingAttributeInfo; + short ArraySizeParamIndex) : MarshallingInfo; /// /// User-applied System.Runtime.InteropServices.BlittableTypeAttribute /// or System.Runtime.InteropServices.GeneratedMarshallingAttribute on a blittable type /// in source in this compilation. /// - internal sealed record BlittableTypeAttributeInfo : MarshallingAttributeInfo; + internal sealed record BlittableTypeAttributeInfo : MarshallingInfo; [Flags] internal enum SupportedMarshallingMethods @@ -47,12 +47,18 @@ internal enum SupportedMarshallingMethods internal sealed record NativeMarshallingAttributeInfo( ITypeSymbol NativeMarshallingType, ITypeSymbol? ValuePropertyType, - SupportedMarshallingMethods MarshallingMethods) : MarshallingAttributeInfo; + SupportedMarshallingMethods MarshallingMethods) : MarshallingInfo; /// /// User-applied System.Runtime.InteropServices.GeneratedMarshallingAttribute /// on a non-blittable type in source in this compilation. /// internal sealed record GeneratedNativeMarshallingAttributeInfo( - string NativeMarshallingFullyQualifiedTypeName) : MarshallingAttributeInfo; + string NativeMarshallingFullyQualifiedTypeName) : MarshallingInfo; + + /// + /// The type of the element is a SafeHandle-derived type with no marshalling attributes. + /// + internal sealed record SafeHandleMarshallingInfo : MarshallingInfo; + } diff --git a/DllImportGenerator/DllImportGenerator/StubCodeContext.cs b/DllImportGenerator/DllImportGenerator/StubCodeContext.cs index 10c254234fcf..0b5f9fe0a862 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeContext.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeContext.cs @@ -46,7 +46,13 @@ public enum Stage /// /// Keep alive any managed objects that need to stay alive across the call. /// - KeepAlive + KeepAlive, + + /// + /// Convert native data to managed data even in the case of an exception during + /// the non-cleanup phases. + /// + GuaranteedUnmarshal } public Stage CurrentStage { get; protected set; } diff --git a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs index e7e870bf888f..043819bc068f 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs @@ -141,6 +141,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt Stage.Invoke, Stage.KeepAlive, Stage.Unmarshal, + Stage.GuaranteedUnmarshal, Stage.Cleanup }; @@ -151,7 +152,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt int initialCount = statements.Count; context.CurrentStage = stage; - if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal)) + if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal || stage == Stage.GuaranteedUnmarshal)) { // Handle setup and unmarshalling for return var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, context); diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index 27481eb846e8..284aa2a3694e 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace DllImportGenerator +namespace Microsoft.Interop { static class TypeNames { @@ -19,5 +19,9 @@ static class TypeNames public const string System_Runtime_InteropServices_StructLayoutAttribute = "System.Runtime.InteropServices.StructLayoutAttribute"; public const string System_Runtime_InteropServices_MarshalAsAttribute = "System.Runtime.InteropServices.MarshalAsAttribute"; + + public const string System_Runtime_InteropServices_MarshalEx = "System.Runtime.InteropServices.MarshalEx"; + + public const string System_Runtime_InteropServices_SafeHandle = "System.Runtime.InteropServices.SafeHandle"; } } diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 17badcfb3405..e816f8457d79 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -39,11 +39,11 @@ private TypePositionInfo() public int NativeIndex { get; set; } public int UnmanagedLCIDConversionArgIndex { get; private set; } - public MarshallingAttributeInfo MarshallingAttributeInfo { get; private set; } + public MarshallingInfo MarshallingAttributeInfo { get; private set; } public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, Compilation compilation) { - var marshallingInfo = GetMarshallingAttributeInfo(paramSymbol.Type, paramSymbol.GetAttributes(), compilation); + var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), compilation); var typeInfo = new TypePositionInfo() { ManagedType = paramSymbol.Type, @@ -58,7 +58,7 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, Compilation compilation) { - var marshallingInfo = GetMarshallingAttributeInfo(type, attributes, compilation); + var marshallingInfo = GetMarshallingInfo(type, attributes, compilation); var typeInfo = new TypePositionInfo() { ManagedType = type, @@ -72,9 +72,9 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, Compilation compilation) + private static MarshallingInfo? GetMarshallingInfo(ITypeSymbol type, IEnumerable attributes, Compilation compilation) { - MarshallingAttributeInfo? marshallingInfo = null; + MarshallingInfo? marshallingInfo = null; // Look at attributes on the type. foreach (var attrData in attributes) { @@ -134,6 +134,11 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable ActiveHandles = new HashSet(); + + [UnmanagedCallersOnly(EntryPoint = "alloc_handle")] + public static nint AllocateHandle() + { + return AllocateHandleCore(); + } + + private static nint AllocateHandleCore() + { + if (LastHandle == int.MaxValue) + { + return InvalidHandle; + } + + nint newHandle = ++LastHandle; + ActiveHandles.Add(newHandle); + return newHandle; + } + + [UnmanagedCallersOnly(EntryPoint = "release_handle")] + public static byte ReleaseHandle(nint handle) + { + return ActiveHandles.Remove(handle) ? 1 : 0; + } + + [UnmanagedCallersOnly(EntryPoint = "is_handle_alive")] + public static byte IsHandleAlive(nint handle) + { + return ActiveHandles.Contains(handle) ? 1 : 0; + } + + [UnmanagedCallersOnly(EntryPoint = "modify_handle")] + public static void ModifyHandle(nint* handle, byte newHandle) + { + if (newHandle != 0) + { + *handle = AllocateHandleCore(); + } + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/designs/Pipeline.md b/DllImportGenerator/designs/Pipeline.md index d42d3a08e4d5..d34be0159c25 100644 --- a/DllImportGenerator/designs/Pipeline.md +++ b/DllImportGenerator/designs/Pipeline.md @@ -38,9 +38,13 @@ Generation of the stub code happens in stages. The marshalling generator for eac 1. `Invoke`: call to the generated P/Invoke - Call `AsArgument` on the marshalling generator for every parameter - Create invocation statement that calls the generated P/Invoke +1. `KeepAlive`: keep alive any objects who's native representation won't keep them alive across the call. + - Call `Generate` on the marshalling generator for every parameter. 1. `Unmarshal`: conversion of native to managed data - If the method has a non-void return, call `Generate` on the marshalling generator for the return - Call `Generate` on the marshalling generator for every parameter +1. `GuaranteedUnmarshal`: conversion of native to managed data even when an exception is thrown + - Call `Generate` on the marshalling generator for every parameter. 1. `Cleanup`: free any allocated resources - Call `Generate` on the marshalling generator for every parameter