diff --git a/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj b/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj
index 536eec52ae8f..e9063db196be 100644
--- a/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj
+++ b/DllImportGenerator/Ancillary.Interop/Ancillary.Interop.csproj
@@ -4,6 +4,7 @@
net5.0
8.0
System.Runtime.InteropServices
+ enable
diff --git a/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs b/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs
index 0f0d322da4dd..42c15094915a 100644
--- a/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs
+++ b/DllImportGenerator/Ancillary.Interop/GeneratedDllImportAttribute.cs
@@ -1,5 +1,4 @@
-#nullable enable
-
+
namespace System.Runtime.InteropServices
{
// [TODO] Remove once the attribute has been added to the BCL
diff --git a/DllImportGenerator/Ancillary.Interop/MarshalEx.cs b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs
new file mode 100644
index 000000000000..95be96d86858
--- /dev/null
+++ b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs
@@ -0,0 +1,29 @@
+
+using System.Reflection;
+
+namespace System.Runtime.InteropServices
+{
+ ///
+ /// Marshalling helper methods that will likely live in S.R.IS.Marshal
+ /// when we integrate our APIs with dotnet/runtime.
+ ///
+ public static class MarshalEx
+ {
+ public static TSafeHandle CreateSafeHandle()
+ where TSafeHandle : SafeHandle
+ {
+ if (typeof(TSafeHandle).IsAbstract || typeof(TSafeHandle).GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.CreateInstance | BindingFlags.Instance, null, Type.EmptyTypes, null) == null)
+ {
+ throw new MissingMemberException($"The safe handle type '{typeof(TSafeHandle).FullName}' must be a non-abstract type with a parameterless constructor.");
+ }
+
+ TSafeHandle safeHandle = (TSafeHandle)Activator.CreateInstance(typeof(TSafeHandle), nonPublic: true)!;
+ return safeHandle;
+ }
+
+ public static void SetHandle(SafeHandle safeHandle, IntPtr handle)
+ {
+ typeof(SafeHandle).GetMethod("SetHandle", BindingFlags.NonPublic | BindingFlags.Instance)!.Invoke(safeHandle, new object[] { handle });
+ }
+ }
+}
diff --git a/DllImportGenerator/Demo/Demo.csproj b/DllImportGenerator/Demo/Demo.csproj
index 4cee92efcd88..fdba3fbbb354 100644
--- a/DllImportGenerator/Demo/Demo.csproj
+++ b/DllImportGenerator/Demo/Demo.csproj
@@ -16,4 +16,8 @@
+
+
+
+
diff --git a/DllImportGenerator/Demo/Program.cs b/DllImportGenerator/Demo/Program.cs
index ff7d931ccd3c..06715a76c7c0 100644
--- a/DllImportGenerator/Demo/Program.cs
+++ b/DllImportGenerator/Demo/Program.cs
@@ -31,6 +31,13 @@ static void Main(string[] args)
c = b;
NativeExportsNE.Sum(a, ref c);
Console.WriteLine($"{a} + {b} = {c}");
+
+ SafeHandleTests tests = new SafeHandleTests();
+
+ tests.ReturnValue_CreatesSafeHandle();
+ tests.ByValue_CorrectlyUnwrapsHandle();
+ tests.ByRefSameValue_UsesSameHandleInstance();
+ tests.ByRefDifferentValue_UsesNewHandleInstance();
}
}
}
diff --git a/DllImportGenerator/Demo/SafeHandleTests.cs b/DllImportGenerator/Demo/SafeHandleTests.cs
new file mode 100644
index 000000000000..bb21cfb565ca
--- /dev/null
+++ b/DllImportGenerator/Demo/SafeHandleTests.cs
@@ -0,0 +1,74 @@
+
+using System.Runtime.InteropServices;
+using Microsoft.Win32.SafeHandles;
+using Xunit;
+
+namespace Demo
+{
+ partial class NativeExportsNE
+ {
+ public class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ private NativeExportsSafeHandle() : base(true)
+ {
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ Assert.True(NativeExportsNE.ReleaseHandle(handle));
+ return true;
+ }
+ }
+
+ [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "alloc_handle")]
+ public static partial NativeExportsSafeHandle AllocateHandle();
+
+ [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "release_handle")]
+ [return:MarshalAs(UnmanagedType.I1)]
+ private static partial bool ReleaseHandle(nint handle);
+
+ [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "is_handle_alive")]
+ [return:MarshalAs(UnmanagedType.I1)]
+ public static partial bool IsHandleAlive(NativeExportsSafeHandle handle);
+
+ [GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "modify_handle")]
+ public static partial void ModifyHandle(ref NativeExportsSafeHandle handle, [MarshalAs(UnmanagedType.I1)] bool newHandle);
+ }
+
+ public class SafeHandleTests
+ {
+ [Fact]
+ public void ReturnValue_CreatesSafeHandle()
+ {
+ using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.AllocateHandle();
+ Assert.False(handle.IsClosed);
+ Assert.False(handle.IsInvalid);
+ }
+
+ [Fact]
+ public void ByValue_CorrectlyUnwrapsHandle()
+ {
+ using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.AllocateHandle();
+ Assert.True(NativeExportsNE.IsHandleAlive(handle));
+ }
+
+ [Fact]
+ public void ByRefSameValue_UsesSameHandleInstance()
+ {
+ using NativeExportsNE.NativeExportsSafeHandle handleToDispose = NativeExportsNE.AllocateHandle();
+ NativeExportsNE.NativeExportsSafeHandle handle = handleToDispose;
+ NativeExportsNE.ModifyHandle(ref handle, false);
+ Assert.Same(handleToDispose, handle);
+ }
+
+ [Fact]
+ public void ByRefDifferentValue_UsesNewHandleInstance()
+ {
+ using NativeExportsNE.NativeExportsSafeHandle handleToDispose = NativeExportsNE.AllocateHandle();
+ NativeExportsNE.NativeExportsSafeHandle handle = handleToDispose;
+ NativeExportsNE.ModifyHandle(ref handle, true);
+ Assert.NotSame(handleToDispose, handle);
+ handle.Dispose();
+ }
+ }
+}
\ No newline at end of file
diff --git a/DllImportGenerator/Directory.Build.props b/DllImportGenerator/Directory.Build.props
index 81194ffe87da..61ba7c1afc32 100644
--- a/DllImportGenerator/Directory.Build.props
+++ b/DllImportGenerator/Directory.Build.props
@@ -2,6 +2,7 @@
3.8.0-3.final
+ 2.4.1
diff --git a/DllImportGenerator/DllImportGenerator.Test/Compiles.cs b/DllImportGenerator/DllImportGenerator.Test/Compiles.cs
index 33b741e8ad25..3c66f36bbe6d 100644
--- a/DllImportGenerator/DllImportGenerator.Test/Compiles.cs
+++ b/DllImportGenerator/DllImportGenerator.Test/Compiles.cs
@@ -90,6 +90,7 @@ public static IEnumerable