From 4ab2b2bad7f3d1befab60448ed0b4d0a1b3e5c29 Mon Sep 17 00:00:00 2001 From: Andrii Kurdiumov Date: Wed, 21 Apr 2021 08:26:01 +0600 Subject: [PATCH] Add basic implementation of the ComWrappers (#653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add basic implementation of the ComWrappers * Implement IUnknown interface for CCW * Introduce InternalComInterfaceDispatch * Pass type to which convert COM instance. * Create separate tests for ComWrappers feature * Read GUID directly from metadata instead of relying on Reflection Co-authored-by: Jan Kotas Co-authored-by: Michal Strehovský --- src/coreclr/nativeaot/Directory.Build.props | 4 + .../Runtime/CompilerHelpers/InteropHelpers.cs | 8 +- .../src/Resources/Strings.resx | 6 + .../src/System.Private.CoreLib.csproj | 1 + .../InteropServices/ComWrappers.CoreRT.cs | 460 ++++++++++++++++++ .../TypeSystem/Interop/IL/Marshaller.Aot.cs | 34 +- .../SmokeTests/ComWrappers/CMakeLists.txt | 7 + .../SmokeTests/ComWrappers/ComWrappers.cs | 162 ++++++ .../SmokeTests/ComWrappers/ComWrappers.csproj | 14 + .../ComWrappers/ComWrappersNative.cpp | 43 ++ .../nativeaot/SmokeTests/ComWrappers/rd.xml | 7 + .../nativeaot/SmokeTests/PInvoke/PInvoke.cs | 26 +- .../SmokeTests/PInvoke/PInvokeNative.cpp | 26 +- 13 files changed, 758 insertions(+), 40 deletions(-) create mode 100644 src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreRT.cs create mode 100644 src/tests/nativeaot/SmokeTests/ComWrappers/CMakeLists.txt create mode 100644 src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs create mode 100644 src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.csproj create mode 100644 src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappersNative.cpp create mode 100644 src/tests/nativeaot/SmokeTests/ComWrappers/rd.xml diff --git a/src/coreclr/nativeaot/Directory.Build.props b/src/coreclr/nativeaot/Directory.Build.props index ad5a22e66dc5..a67c97aac2e1 100644 --- a/src/coreclr/nativeaot/Directory.Build.props +++ b/src/coreclr/nativeaot/Directory.Build.props @@ -57,6 +57,10 @@ FEATURE_COMINTEROP;$(DefineConstants) + + false + true + diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs index 18f92a22d490..d4840392553e 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs @@ -405,14 +405,20 @@ public static T GetCurrentCalleeDelegate() where T : class // constraint can' return PInvokeMarshal.GetCurrentCalleeDelegate(); } - public static IntPtr ConvertManagedComInterfaceToNative(object pUnk) + public static IntPtr ConvertManagedComInterfaceToNative(object pUnk, Guid interfaceGuid) { if (pUnk == null) { return IntPtr.Zero; } +#if TARGET_WINDOWS +#pragma warning disable CA1416 + return ComWrappers.ComInterfaceForObject(pUnk, interfaceGuid); +#pragma warning restore CA1416 +#else throw new PlatformNotSupportedException(SR.PlatformNotSupported_ComInterop); +#endif } public static object ConvertNativeComInterfaceToManaged(IntPtr pUnk) diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/Resources/Strings.resx b/src/coreclr/nativeaot/System.Private.CoreLib/src/Resources/Strings.resx index 56efba440df6..4654b1f0b9b1 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/Resources/Strings.resx +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/Resources/Strings.resx @@ -3151,7 +3151,13 @@ The argv[0] argument cannot include a double quote. + + Attempt to update previously set global instance. + Use of ResourceManager for custom types is disabled. Set the MSBuild Property CustomResourceTypesSupport to true in order to enable it. + + COM Interop requires ComWrapper instance registered for marshalling. + diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj b/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj index 12528cdb9e48..2dca5df08ee8 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj @@ -197,6 +197,7 @@ + diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreRT.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreRT.cs new file mode 100644 index 000000000000..1677210f90eb --- /dev/null +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreRT.cs @@ -0,0 +1,460 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using Internal.Runtime.CompilerServices; + +namespace System.Runtime.InteropServices +{ + /// + /// Class for managing wrappers of COM IUnknown types. + /// + public abstract partial class ComWrappers + { + internal static IntPtr DefaultIUnknownVftblPtr { get; } = CreateDefaultIUnknownVftbl(); + + internal static Guid IID_IUnknown = new Guid(0x00000000, 0x0000, 0x0000, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46); + + private readonly ConditionalWeakTable _ccwTable = new ConditionalWeakTable(); + + /// + /// ABI for function dispatch of a COM interface. + /// + public unsafe partial struct ComInterfaceDispatch + { + /// + /// Given a from a generated Vtable, convert to the target type. + /// + /// Desired type. + /// Pointer supplied to Vtable function entry. + /// Instance of type associated with dispatched function call. + public static unsafe T GetInstance(ComInterfaceDispatch* dispatchPtr) where T : class + { + ManagedObjectWrapper* comInstance = ToManagedObjectWrapper(dispatchPtr); + return Unsafe.As(RuntimeImports.RhHandleGet(comInstance->Target)); + } + + internal static unsafe ManagedObjectWrapper* ToManagedObjectWrapper(ComInterfaceDispatch* dispatchPtr) + { + return ((InternalComInterfaceDispatch*)dispatchPtr)->_thisPtr; + } + } + + internal unsafe struct InternalComInterfaceDispatch + { + public IntPtr Vtable; + internal ManagedObjectWrapper* _thisPtr; + } + + internal unsafe struct ManagedObjectWrapper + { + public IntPtr Target; // This is GC Handle + public uint RefCount; + + public int UserDefinedCount; + public ComInterfaceEntry* UserDefined; + internal InternalComInterfaceDispatch* Dispatches; + + internal CreateComInterfaceFlags Flags; + + public uint AddRef() + { + return Interlocked.Increment(ref RefCount); + } + + public uint Release() + { + Debug.Assert(RefCount != 0); + return Interlocked.Decrement(ref RefCount); + } + + public unsafe int QueryInterface(in Guid riid, out IntPtr ppvObject) + { + ppvObject = AsRuntimeDefined(in riid); + if (ppvObject == IntPtr.Zero) + { + ppvObject = AsUserDefined(in riid); + if (ppvObject == IntPtr.Zero) + return HResults.COR_E_INVALIDCAST; + } + + AddRef(); + return HResults.S_OK; + } + + public IntPtr As(in Guid riid) + { + // Find target interface and return dispatcher or null if not found. + IntPtr typeMaybe = AsRuntimeDefined(in riid); + if (typeMaybe == IntPtr.Zero) + typeMaybe = AsUserDefined(in riid); + + return typeMaybe; + } + + public unsafe void Destroy() + { + if (Target == IntPtr.Zero) + { + return; + } + + RuntimeImports.RhHandleFree(Target); + Target = IntPtr.Zero; + } + + private unsafe IntPtr AsRuntimeDefined(in Guid riid) + { + if ((Flags & CreateComInterfaceFlags.CallerDefinedIUnknown) == CreateComInterfaceFlags.None) + { + if (riid == IID_IUnknown) + { + return (IntPtr)(Dispatches + UserDefinedCount); + } + } + + return IntPtr.Zero; + } + + private unsafe IntPtr AsUserDefined(in Guid riid) + { + for (int i = 0; i < UserDefinedCount; ++i) + { + if (UserDefined[i].IID == riid) + { + return (IntPtr)(Dispatches + i); + } + } + + return IntPtr.Zero; + } + } + + internal unsafe class ManagedObjectWrapperHolder + { + private ManagedObjectWrapper* _wrapper; + + public ManagedObjectWrapperHolder(ManagedObjectWrapper* wrapper) + { + _wrapper = wrapper; + } + + public unsafe IntPtr ComIp => _wrapper->As(in ComWrappers.IID_IUnknown); + + ~ManagedObjectWrapperHolder() + { + // Release GC handle created when MOW was built. + _wrapper->Destroy(); + Marshal.FreeCoTaskMem((IntPtr)_wrapper); + } + } + + internal unsafe struct IUnknownVftbl + { + public delegate* unmanaged QueryInterface; + public delegate* unmanaged AddRef; + public delegate* unmanaged Release; + } + +#if false + /// + /// Globally registered instance of the ComWrappers class for reference tracker support. + /// + private static ComWrappers? s_globalInstanceForTrackerSupport; +#endif + + /// + /// Globally registered instance of the ComWrappers class for marshalling. + /// + private static ComWrappers? s_globalInstanceForMarshalling; + + /// + /// Create a COM representation of the supplied object that can be passed to a non-managed environment. + /// + /// The managed object to expose outside the .NET runtime. + /// Flags used to configure the generated interface. + /// The generated COM interface that can be passed outside the .NET runtime. + /// + /// If a COM representation was previously created for the specified using + /// this instance, the previously created COM interface will be returned. + /// If not, a new one will be created. + /// + public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags) + { + if (instance == null) + throw new ArgumentNullException(nameof(instance)); + + ManagedObjectWrapperHolder ccwValue; + if (_ccwTable.TryGetValue(instance, out ccwValue)) + { + return ccwValue.ComIp; + } + + ccwValue = _ccwTable.GetValue(instance, (c) => + { + ManagedObjectWrapper* value = CreateCCW(this, c, flags); + return new ManagedObjectWrapperHolder(value); + }); + return ccwValue.ComIp; + } + + private static unsafe ManagedObjectWrapper* CreateCCW(ComWrappers impl, object instance, CreateComInterfaceFlags flags) + { + ComInterfaceEntry* userDefined = impl.ComputeVtables(instance, flags, out int userDefinedCount); + + // Maximum number of runtime supplied vtables. + Span runtimeDefinedVtable = stackalloc IntPtr[4]; + int runtimeDefinedCount = 0; + + // Check if the caller will provide the IUnknown table. + if ((flags & CreateComInterfaceFlags.CallerDefinedIUnknown) == CreateComInterfaceFlags.None) + { + runtimeDefinedVtable[runtimeDefinedCount++] = DefaultIUnknownVftblPtr; + } + + // Compute size for ManagedObjectWrapper instance. + int totalDefinedCount = runtimeDefinedCount + userDefinedCount; + + // Allocate memory for the ManagedObjectWrapper. + IntPtr wrapperMem = Marshal.AllocCoTaskMem( + sizeof(ManagedObjectWrapper) + totalDefinedCount * sizeof(InternalComInterfaceDispatch)); + + // Compute the dispatch section offset and ensure it is aligned. + ManagedObjectWrapper* mow = (ManagedObjectWrapper*)wrapperMem; + + // Dispatches follow immediately after ManagedObjectWrapper + InternalComInterfaceDispatch* pDispatches = (InternalComInterfaceDispatch*)(wrapperMem + sizeof(ManagedObjectWrapper)); + for (int i = 0; i < totalDefinedCount; i++) + { + pDispatches[i].Vtable = (i < userDefinedCount) ? userDefined[i].Vtable : runtimeDefinedVtable[i - userDefinedCount]; + pDispatches[i]._thisPtr = mow; + } + + mow->Target = RuntimeImports.RhHandleAlloc(instance, GCHandleType.Normal); + mow->RefCount = 0; + mow->UserDefinedCount = userDefinedCount; + mow->UserDefined = userDefined; + mow->Flags = flags; + mow->Dispatches = pDispatches; + return mow; + } + + /// + /// Get the currently registered managed object or creates a new managed object and registers it. + /// + /// Object to import for usage into the .NET runtime. + /// Flags used to describe the external object. + /// Returns a managed object associated with the supplied external COM object. + /// + /// If a managed object was previously created for the specified + /// using this instance, the previously created object will be returned. + /// If not, a new one will be created. + /// + public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags) + { + object? obj; + if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, IntPtr.Zero, flags, null, out obj)) + throw new ArgumentNullException(nameof(externalComObject)); + + return obj!; + } + + /// + /// Get the currently registered managed object or uses the supplied managed object and registers it. + /// + /// Object to import for usage into the .NET runtime. + /// Flags used to describe the external object. + /// The to be used as the wrapper for the external object + /// Returns a managed object associated with the supplied external COM object. + /// + /// If the instance already has an associated external object a will be thrown. + /// + public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags, object wrapper) + { + return GetOrRegisterObjectForComInstance(externalComObject, flags, wrapper, IntPtr.Zero); + } + + /// + /// Get the currently registered managed object or uses the supplied managed object and registers it. + /// + /// Object to import for usage into the .NET runtime. + /// Flags used to describe the external object. + /// The to be used as the wrapper for the external object + /// Inner for COM aggregation scenarios + /// Returns a managed object associated with the supplied external COM object. + /// + /// This method override is for registering an aggregated COM instance with its associated inner. The inner + /// will be released when the associated wrapper is eventually freed. Note that it will be released on a thread + /// in an unknown apartment state. If the supplied inner is not known to be a free-threaded instance then + /// it is advised to not supply the inner. + /// + /// If the instance already has an associated external object a will be thrown. + /// + public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags, object wrapper, IntPtr inner) + { + if (wrapper == null) + throw new ArgumentNullException(nameof(wrapper)); + + object? obj; + if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, inner, flags, wrapper, out obj)) + throw new ArgumentNullException(nameof(externalComObject)); + + return obj!; + } + + /// + /// Get the currently registered managed object or creates a new managed object and registers it. + /// + /// The implementation to use when creating the managed object. + /// Object to import for usage into the .NET runtime. + /// The inner instance if aggregation is involved + /// Flags used to describe the external object. + /// The to be used as the wrapper for the external object. + /// The managed object associated with the supplied external COM object or null if it could not be created. + /// Returns true if a managed object could be retrieved/created, false otherwise + /// + /// If is null, the global instance (if registered) will be used. + /// + private static bool TryGetOrCreateObjectForComInstanceInternal( + ComWrappers impl, + IntPtr externalComObject, + IntPtr innerMaybe, + CreateObjectFlags flags, + object? wrapperMaybe, + out object? retValue) + { + if (externalComObject == IntPtr.Zero) + throw new ArgumentNullException(nameof(externalComObject)); + + if (flags.HasFlag(CreateObjectFlags.Aggregation)) + throw new NotImplementedException(); + + object? wrapperMaybeLocal = wrapperMaybe; + retValue = null; + throw new NotImplementedException(); + } + + /// + /// Register a instance to be used as the global instance for reference tracker support. + /// + /// Instance to register + /// + /// This function can only be called a single time. Subsequent calls to this function will result + /// in a being thrown. + /// + /// Scenarios where this global instance may be used are: + /// * Object tracking via the and flags. + /// + public static void RegisterForTrackerSupport(ComWrappers instance) + { +#if false + if (instance == null) + throw new ArgumentNullException(nameof(instance)); + + if (null != Interlocked.CompareExchange(ref s_globalInstanceForTrackerSupport, instance, null)) + { + throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance); + } +#else + throw new NotImplementedException(); +#endif + } + + /// + /// Register a instance to be used as the global instance for marshalling in the runtime. + /// + /// Instance to register + /// + /// This function can only be called a single time. Subsequent calls to this function will result + /// in a being thrown. + /// + /// Scenarios where this global instance may be used are: + /// * Usage of COM-related Marshal APIs + /// * P/Invokes with COM-related types + /// * COM activation + /// + public static void RegisterForMarshalling(ComWrappers instance) + { + if (instance == null) + throw new ArgumentNullException(nameof(instance)); + + if (null != Interlocked.CompareExchange(ref s_globalInstanceForMarshalling, instance, null)) + { + throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance); + } + } + + /// + /// Get the runtime provided IUnknown implementation. + /// + /// Function pointer to QueryInterface. + /// Function pointer to AddRef. + /// Function pointer to Release. + protected internal static unsafe void GetIUnknownImpl(out IntPtr fpQueryInterface, out IntPtr fpAddRef, out IntPtr fpRelease) + { + fpQueryInterface = (IntPtr)(delegate* unmanaged)&ComWrappers.IUnknown_QueryInterface; + fpAddRef = (IntPtr)(delegate* unmanaged)&ComWrappers.IUnknown_AddRef; + fpRelease = (IntPtr)(delegate* unmanaged)&ComWrappers.IUnknown_Release; + } + + internal static IntPtr ComInterfaceForObject(object instance) + { + if (s_globalInstanceForMarshalling == null) + { + throw new InvalidOperationException(SR.InvalidOperation_ComInteropRequireComWrapperInstance); + } + + return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.None); + } + + internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid targetIID) + { + IntPtr unknownPtr = ComInterfaceForObject(instance); + IntPtr comObjectInterface; + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)unknownPtr); + int resultCode = wrapper->QueryInterface(in targetIID, out comObjectInterface); + if (resultCode != 0) + { + throw new PlatformNotSupportedException(SR.PlatformNotSupported_ComInterop); + } + + return comObjectInterface; + } + + [UnmanagedCallersOnly] + internal static unsafe int IUnknown_QueryInterface(IntPtr pThis, Guid* guid, IntPtr* ppObject) + { + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis); + return wrapper->QueryInterface(in *guid, out *ppObject); + } + + [UnmanagedCallersOnly] + internal static unsafe uint IUnknown_AddRef(IntPtr pThis) + { + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis); + return wrapper->AddRef(); + } + + [UnmanagedCallersOnly] + internal static unsafe uint IUnknown_Release(IntPtr pThis) + { + ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis); + uint refcount = wrapper->Release(); + if (refcount == 0) + { + wrapper->Destroy(); + } + + return refcount; + } + + private static unsafe IntPtr CreateDefaultIUnknownVftbl() + { + IntPtr* vftbl = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappers), 3 * sizeof(IntPtr)); + GetIUnknownImpl(out vftbl[0], out vftbl[1], out vftbl[2]); + return (IntPtr)vftbl; + } + } +} diff --git a/src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.Aot.cs b/src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.Aot.cs index 6851eb862941..59f6523ece25 100644 --- a/src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.Aot.cs +++ b/src/coreclr/tools/Common/TypeSystem/Interop/IL/Marshaller.Aot.cs @@ -2,12 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers.Binary; using System.Runtime.InteropServices; using Internal.IL.Stubs; using Internal.IL; using Debug = System.Diagnostics.Debug; using ILLocalVariable = Internal.IL.Stubs.ILLocalVariable; +using Internal.TypeSystem.Ecma; +using System.Reflection.Metadata; namespace Internal.TypeSystem.Interop { @@ -871,8 +874,35 @@ protected override void AllocAndTransformManagedToNative(ILCodeStream codeStream { ILEmitter emitter = _ilCodeStreams.Emitter; - var helper = Context.GetHelperEntryPoint("InteropHelpers", "ConvertManagedComInterfaceToNative"); + MethodDesc helper = Context.GetHelperEntryPoint("InteropHelpers", "ConvertManagedComInterfaceToNative"); LoadManagedValue(codeStream); + CustomAttributeValue? guidAttributeValue = (this.ManagedParameterType as EcmaType)? + .GetDecodedCustomAttribute("System.Runtime.InteropServices", "GuidAttribute"); + if (guidAttributeValue == null) + { + throw new NotSupportedException(); + } + + var guidValue = (string)guidAttributeValue.Value.FixedArguments[0].Value; + Span bytes = Guid.Parse(guidValue).ToByteArray(); + codeStream.EmitLdc(BinaryPrimitives.ReadInt32LittleEndian(bytes)); + codeStream.EmitLdc(BinaryPrimitives.ReadInt16LittleEndian(bytes.Slice(4))); + codeStream.EmitLdc(BinaryPrimitives.ReadInt16LittleEndian(bytes.Slice(6))); + for (int i = 8; i < 16; i++) + codeStream.EmitLdc(bytes[i]); + + MetadataType guidType = Context.SystemModule.GetKnownType("System", "Guid"); + var int32Type = Context.GetWellKnownType(WellKnownType.Int32); + var int16Type = Context.GetWellKnownType(WellKnownType.Int16); + var byteType = Context.GetWellKnownType(WellKnownType.Byte); + var sig = new MethodSignature( + MethodSignatureFlags.None, + genericParameterCount: 0, + returnType: Context.GetWellKnownType(WellKnownType.Void), + parameters: new TypeDesc[] { int32Type, int16Type, int16Type, byteType, byteType, byteType, byteType, byteType, byteType, byteType, byteType }); + MethodDesc guidCtorHandleMethod = + guidType.GetKnownMethod(".ctor", sig); + codeStream.Emit(ILOpcode.newobj, emitter.NewToken(guidCtorHandleMethod)); codeStream.Emit(ILOpcode.call, emitter.NewToken(helper)); @@ -883,7 +913,7 @@ protected override void AllocAndTransformNativeToManaged(ILCodeStream codeStream { ILEmitter emitter = _ilCodeStreams.Emitter; - var helper = Context.GetHelperEntryPoint("InteropHelpers", "ConvertNativeComInterfaceToManaged"); + MethodDesc helper = Context.GetHelperEntryPoint("InteropHelpers", "ConvertNativeComInterfaceToManaged"); LoadNativeValue(codeStream); codeStream.Emit(ILOpcode.call, emitter.NewToken(helper)); diff --git a/src/tests/nativeaot/SmokeTests/ComWrappers/CMakeLists.txt b/src/tests/nativeaot/SmokeTests/ComWrappers/CMakeLists.txt new file mode 100644 index 000000000000..a631acce2c89 --- /dev/null +++ b/src/tests/nativeaot/SmokeTests/ComWrappers/CMakeLists.txt @@ -0,0 +1,7 @@ +project (ComWrappersNative) +include_directories(${INC_PLATFORM_DIR}) + +add_library (ComWrappersNative SHARED ComWrappersNative.cpp) + +# add the install targets +install (TARGETS ComWrappersNative DESTINATION bin) diff --git a/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs b/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs new file mode 100644 index 000000000000..4ed7ef805dc1 --- /dev/null +++ b/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.cs @@ -0,0 +1,162 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; + +namespace ComWrappersTests +{ + internal class Program + { + [DynamicDependency(DynamicallyAccessedMemberTypes.PublicMethods, typeof(IComInterface))] + public static int Main(string[] args) + { + TestComInteropNullPointers(); + TestComInteropRegistrationRequired(); + TestComInteropReleaseProcess(); + return 100; + } + + public static void ThrowIfNotEquals(T expected, T actual, string message) + { + if (!expected.Equals(actual)) + { + message += "\nExpected: " + expected.ToString() + "\n"; + message += "Actual: " + actual.ToString() + "\n"; + throw new Exception(message); + } + } + + public static void ThrowIfNotEquals(bool expected, bool actual, string message) + { + ThrowIfNotEquals(expected ? 1 : 0, actual ? 1 : 0, message); + } + + [DllImport("ComWrappersNative", CallingConvention = CallingConvention.StdCall)] + static extern bool IsNULL(IComInterface foo); + + [DllImport("ComWrappersNative", CallingConvention = CallingConvention.StdCall)] + static extern int CaptureComPointer(IComInterface foo); + + [DllImport("ComWrappersNative", CallingConvention = CallingConvention.StdCall)] + static extern void ReleaseComPointer(); + + public static void TestComInteropNullPointers() + { + Console.WriteLine("Testing Marshal APIs for COM interfaces"); + IComInterface comPointer = null; + var result = IsNULL(comPointer); + ThrowIfNotEquals(true, IsNULL(comPointer), "COM interface marshalling null check failed"); + } + + public static void TestComInteropRegistrationRequired() + { + Console.WriteLine("Testing COM Interop registration process"); + ComObject target = new ComObject(); + try + { + CaptureComPointer(target); + throw new Exception("Cannot work without ComWrappers.RegisterForMarshalling called"); + } + catch (InvalidOperationException) + { + } + } + + public static void TestComInteropReleaseProcess() + { + Console.WriteLine("Testing CCW release process"); + ComWrappers wrapper = new SimpleComWrapper(); + ComWrappers.RegisterForMarshalling(wrapper); + WeakReference comPointerHolder = CreateComReference(); + + GC.Collect(); + ThrowIfNotEquals(true, comPointerHolder.IsAlive, ".NET object should be alive"); + + ReleaseComPointer(); + + GC.Collect(); + ThrowIfNotEquals(false, comPointerHolder.IsAlive, ".NET object should be disposed by then"); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static WeakReference CreateComReference() + { + ComObject target = new ComObject(); + WeakReference comPointerHolder = new WeakReference(target); + + int result = CaptureComPointer(target); + ThrowIfNotEquals(0, result, "Seems to be COM marshalling behave stragerly."); + ThrowIfNotEquals(11, target.TestResult, "Call to method should work"); + + return comPointerHolder; + } + } + + [ComImport] + [ComVisible(true)] + [Guid("111e91ef-1887-4afd-81e3-70cf08e715d8")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + public interface IComInterface + { + int DoWork(int param); + } + + public class ComObject : IComInterface + { + public int TestResult; + public int DoWork(int param) + { + this.TestResult += param; + return 0; + } + } + + internal unsafe class SimpleComWrapper : ComWrappers + { + static ComInterfaceEntry* wrapperEntry; + + static SimpleComWrapper() + { + IntPtr* vtbl = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IComInterface), 4 * sizeof(IntPtr)); + GetIUnknownImpl(out vtbl[0], out vtbl[1], out vtbl[2]); + vtbl[3] = (IntPtr)(delegate* unmanaged)&IComInterfaceProxy.DoWork; + + var comInterfaceEntryMemory = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IComInterface), sizeof(ComInterfaceEntry)); + wrapperEntry = (ComInterfaceEntry*)comInterfaceEntryMemory; + wrapperEntry->IID = new Guid("111e91ef-1887-4afd-81e3-70cf08e715d8"); + wrapperEntry->Vtable = (IntPtr)vtbl; + } + + protected override unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + if (obj is not IComInterface) + throw new Exception(); + count = 1; + return wrapperEntry; + } + + protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flags) + { + return null; + } + + protected override void ReleaseObjects(System.Collections.IEnumerable objects) + { + } + } + + internal unsafe class IComInterfaceProxy + { + [UnmanagedCallersOnly] + public static int DoWork(IntPtr thisPtr, int param) + { + var inst = ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)thisPtr); + return inst.DoWork(param); + } + } +} diff --git a/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.csproj b/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.csproj new file mode 100644 index 000000000000..22ae33c751de --- /dev/null +++ b/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappers.csproj @@ -0,0 +1,14 @@ + + + Exe + true + true + + + + + + + + + diff --git a/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappersNative.cpp b/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappersNative.cpp new file mode 100644 index 000000000000..027692bad0e1 --- /dev/null +++ b/src/tests/nativeaot/SmokeTests/ComWrappers/ComWrappersNative.cpp @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#include +#include +#include +#ifdef TARGET_WINDOWS +#include +#define DLL_EXPORT extern "C" __declspec(dllexport) +#else +#include +#define DLL_EXPORT extern "C" __attribute((visibility("default"))) +#endif + +#if !defined(__stdcall) +#define __stdcall +#endif + +DLL_EXPORT bool __stdcall IsNULL(void *a) +{ + return a == NULL; +} + +#ifdef TARGET_WINDOWS +class IComInterface: public IUnknown +{ +public: + virtual HRESULT STDMETHODCALLTYPE DoWork(int param) = 0; +}; +GUID IID_IComInterface = { 0x111e91ef, 0x1887, 0x4afd, { 0x81, 0xe3, 0x70, 0xcf, 0x08, 0xe7, 0x15, 0xd8 } }; + +IComInterface* capturedComObject; +DLL_EXPORT int __stdcall CaptureComPointer(IComInterface* pUnk) +{ + capturedComObject = pUnk; + return capturedComObject->DoWork(11); +} + +DLL_EXPORT void ReleaseComPointer() +{ + capturedComObject->Release(); +} +#endif diff --git a/src/tests/nativeaot/SmokeTests/ComWrappers/rd.xml b/src/tests/nativeaot/SmokeTests/ComWrappers/rd.xml new file mode 100644 index 000000000000..352d58c52e92 --- /dev/null +++ b/src/tests/nativeaot/SmokeTests/ComWrappers/rd.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs b/src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs index f20a82ff4e37..933a5710c9b1 100644 --- a/src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs +++ b/src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs @@ -209,9 +209,6 @@ private static extern bool VerifySizeParamIndex( [DllImport("PInvokeNative", CallingConvention = CallingConvention.StdCall)] static extern bool IsNULL(SequentialStruct[] foo); - [DllImport("PInvokeNative", CallingConvention = CallingConvention.StdCall)] - static extern bool IsNULL(IComInterface foo); - [StructLayout(LayoutKind.Sequential, CharSet= CharSet.Ansi, Pack = 4)] public unsafe struct InlineArrayStruct { @@ -297,9 +294,6 @@ public static int Main(string[] args) TestMarshalStructAPIs(); TestForwardDelegateWithUnmanagedCallersOnly(); - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - TestComInteropNullPointers(); - return 100; } @@ -360,7 +354,7 @@ private static void TestArrays() Foo[] arr_foo = null; ThrowIfNotEquals(true, IsNULL(arr_foo), "Blittable array null check failed"); - + arr_foo = new Foo[ArraySize]; for (int i = 0; i < ArraySize; i++) { @@ -801,7 +795,7 @@ private static void TestStruct() ssa[i].f1 = 0; ssa[i].f1 = i; ssa[i].f2 = i*i; - ssa[i].f3 = i.LowLevelToString(); + ssa[i].f3 = i.LowLevelToString(); } ThrowIfNotEquals(true, StructTest_Array(ssa, ssa.Length), "Array of struct marshalling failed"); @@ -1011,14 +1005,6 @@ public static unsafe void TestForwardDelegateWithUnmanagedCallersOnly() Action a = Marshal.GetDelegateForFunctionPointer((IntPtr)(void*)(delegate* unmanaged)&UnmanagedCallback); a(); } - - public static void TestComInteropNullPointers() - { - Console.WriteLine("Testing Marshal APIs for COM interfaces"); - IComInterface comPointer = null; - var result = IsNULL(comPointer); - ThrowIfNotEquals(true, IsNULL(comPointer), "COM interface marshalling null check failed"); - } } public class SafeMemoryHandle : SafeHandle //SafeHandle subclass @@ -1047,14 +1033,6 @@ override protected bool ReleaseHandle() } } //end of SafeMemoryHandle class - [ComImport] - [ComVisible(true)] - [Guid("D6DD68D1-86FD-4332-8666-9ABEDEA2D24C")] - [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] - public interface IComInterface - { - } - public static class LowLevelExtensions { // Int32.ToString() calls into glob/loc garbage that hits CppCodegen limitations diff --git a/src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp b/src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp index 5e20f6af0563..21877e3e7f54 100644 --- a/src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp +++ b/src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp @@ -20,7 +20,7 @@ #if (_MSC_VER >= 1400) // Check MSC version #pragma warning(push) #pragma warning(disable: 4996) // Disable deprecation -#endif +#endif void* MemAlloc(long bytes) { @@ -82,7 +82,7 @@ DLL_EXPORT int __stdcall CheckIncremental_Foo(Foo *array, int sz) return 1; } return 0; -} +} DLL_EXPORT int __stdcall Inc(int *val) { @@ -104,7 +104,7 @@ DLL_EXPORT int __stdcall VerifyByRefFoo(Foo *val) val->b++; return 0; -} +} DLL_EXPORT bool __stdcall GetNextChar(short *value) { @@ -129,7 +129,7 @@ int CompareUnicodeString(const unsigned short *val, const unsigned short *expect return 0; const unsigned short *p = val; const unsigned short *q = expected; - + while (*p && *q && *p == *q) { p++; @@ -196,7 +196,7 @@ DLL_EXPORT int __stdcall VerifyAnsiStringArray(char **val) void ToUpper(char *val) { - if (val == NULL) + if (val == NULL) return; char *p = val; while (*p != '\0') @@ -236,7 +236,7 @@ DLL_EXPORT int __stdcall VerifyUnicodeStringOut(unsigned short **val) unsigned short expected[] = { 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', 0 }; for (int i = 0; i < 12; i++) p[i] = expected[i]; - + *val = p; return 1; } @@ -252,7 +252,7 @@ DLL_EXPORT int __stdcall VerifyUnicodeStringRef(unsigned short **val) if (!CompareUnicodeString(p, q)) return 0; - + MemFree(*val); p = (unsigned short*)MemAlloc(sizeof(unsigned short) * 13); @@ -401,7 +401,7 @@ DLL_EXPORT int __stdcall VerifyUnicodeStringBuilderOut(unsigned short *val) unsigned short src[] = { 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', 0 }; for (int i = 0; i < 12; i++) val[i] = src[i]; - + return 1; } @@ -519,7 +519,7 @@ DLL_EXPORT bool __stdcall StructTest_Array(NativeSequentialStruct *nss, int leng { if (nss == NULL) return false; - + char expected[16]; for (int i = 0; i < 3; i++) @@ -562,7 +562,7 @@ DLL_EXPORT bool __stdcall InlineArrayTest(inlineStruct* p, inlineUnicodeStruct * return false; p->inlineArray[i] = i + 1; } - + if (CompareAnsiString(p->inlineString, "Hello") != 1) return false; @@ -613,7 +613,7 @@ DLL_EXPORT bool __stdcall StructTest_Nested(NativeNestedStruct nns) { if (nns.a != 100) return false; - + return StructTest_Explicit(nns.nes); } @@ -624,9 +624,9 @@ DLL_EXPORT bool __stdcall VerifyAnsiCharArrayIn(char *a) DLL_EXPORT bool __stdcall VerifyAnsiCharArrayOut(char *a) { - if (a == NULL) + if (a == NULL) return false; - + CopyAnsiString(a, "Hello World!"); return true; }