Skip to content

Commit

Permalink
Slight push toward CCW.
Browse files Browse the repository at this point in the history
I have 2 ideas. CCW seems to be rathe rcompilcated. But RCW require just having Dictionary<IntrPtr, object> to keep track of RCWs. Most likely this should be some weekly referenced object, which should release COM object when no longer referenced. In fnalizer probably.
  • Loading branch information
kant2002 committed Feb 16, 2021
1 parent c22f302 commit 31f8365
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
<PropertyGroup>
<DefineConstants Condition="'$(FeatureCominterop)' == 'true'">FEATURE_COMINTEROP;$(DefineConstants)</DefineConstants>
</PropertyGroup>
<PropertyGroup>
<FeatureComWrappers>false</FeatureComWrappers>
<FeatureComWrappers Condition="'$(TargetsWindows)' == 'true'">true</FeatureComWrappers>
</PropertyGroup>
<PropertyGroup>
<DefineConstants Condition="'$(FeatureComWrappers)' == 'true'">FEATURE_COMWRAPPERS;$(DefineConstants)</DefineConstants>
</PropertyGroup>
<PropertyGroup>
<FeaturePortableThreadPool Condition="'$(FeaturePortableThreadPool)' == ''">false</FeaturePortableThreadPool>
<FeaturePortableThreadPool Condition="'$(TargetsUnix)' == 'true'">true</FeaturePortableThreadPool>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,7 @@ internal enum ComWrappersScenario
/// </summary>
public abstract partial class ComWrappers
{
private ConditionalWeakTable<object, IntPtr> _ccwCache = new ConditionalWeakTable<object, IntPtr>();
public static unsafe IUnknownVftbl IUnknownVftbl => Unsafe.AsRef<IUnknownVftbl>(IUnknownVftblPtr.ToPointer());

internal static IntPtr IUnknownVftblPtr { get; }

static unsafe ComWrappers()
{
GetIUnknownImpl(out var qi, out var addRef, out var release);

IUnknownVftblPtr = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVftbl), sizeof(IUnknownVftbl));
(*(IUnknownVftbl*)IUnknownVftblPtr) = new IUnknownVftbl
{
QueryInterface = (delegate* unmanaged[Stdcall]<IntPtr, ref Guid, out IntPtr, int>)qi,
AddRef = (delegate* unmanaged[Stdcall]<IntPtr, uint>)addRef,
Release = (delegate* unmanaged[Stdcall]<IntPtr, uint>)release,
};
}
private static readonly ConditionalWeakTable<object, object> CCWTable = new ConditionalWeakTable<object, object>();

/// <summary>
/// ABI for function dispatch of a COM interface.
Expand Down Expand Up @@ -113,56 +97,63 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa
/// <remarks>
/// If <paramref name="impl" /> is <c>null</c>, the global instance (if registered) will be used.
/// </remarks>
private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue)
private static unsafe bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue)
{
if (instance == null)
throw new ArgumentNullException(nameof(instance));

bool success = true;
IntPtr retValue = _ccwCache.GetValue(instance, (c) =>
object ccwValue = CCWTable.GetValue(instance, (c) =>
{
var vtblEntries = impl.ComputeVtables(instance, flags, out var count);
// Here I should create someing like that https://github.com/dotnet/runtime/blob/9f8aab73d93156933ae65a476204bf62c02f6537/src/coreclr/interop/comwrappers.cpp#L16
// Which would be saved to CCW cache.
// Creation of CCW is basically ManagedObjectWrapper::Create reimplementation.

// Maximum number of runtime supplied vtables.
ComInterfaceEntry* runtimeDefinedLocal = stackallock ComInterfaceEntry[4];
int runtimeDefinedCount = 0;

// Check if the caller will provide the IUnknown table.
if ((flags & CreateComInterfaceFlags.CallerDefinedIUnknown) == CreateComInterfaceFlags.None)
{
ComInterfaceEntry* curr = runtimeDefinedLocal[runtimeDefinedCount++];
curr->IID = __uuidof(IUnknown);
curr->Vtable = &ManagedObjectWrapper_IUnknownImpl;
}

// Check if the caller wants tracker support.
if ((flags & CreateComInterfaceFlags.TrackerSupport) == CreateComInterfaceFlags.TrackerSupport)
{
ComInterfaceEntry* curr = runtimeDefinedLocal[runtimeDefinedCount++];
curr->IID = __uuidof(IReferenceTrackerTarget);
curr->Vtable = &ManagedObjectWrapper_IReferenceTrackerTargetImpl;
}

// Compute size for ManagedObjectWrapper instance.
nuint totalRuntimeDefinedSize = runtimeDefinedCount * sizeof(ComInterfaceEntry);
nuint totalDefinedCount = (nuint)runtimeDefinedCount + userDefinedCount;

// Compute the total entry size of dispatch section.
nuint totalDispatchSectionCount = ComputeThisPtrForDispatchSection(totalDefinedCount) + totalDefinedCount;
nuint totalDispatchSectionSize = totalDispatchSectionCount * sizeof(void*);

// Allocate memory for the ManagedObjectWrapper.
char* wrapperMem = (char*)InteropLibImports::MemAlloc(sizeof(ManagedObjectWrapper) + totalRuntimeDefinedSize + totalDispatchSectionSize + ABI::AlignmentThisPtrMaxPadding, AllocScenario::ManagedObjectWrapper);

success = false;
return IntPtr.Zero;
var value = CreateCCW(impl, c, flags);
success = value != null;
return value;
});
retValue = IntPtr.Zero;
return success;
}

private static unsafe object CreateCCW(ComWrappers impl, object instance, CreateComInterfaceFlags flags)
{
var vtblEntries = impl.ComputeVtables(instance, flags, out var userDefinedCount);
// Here I should create someing like that https://github.com/dotnet/runtime/blob/9f8aab73d93156933ae65a476204bf62c02f6537/src/coreclr/interop/comwrappers.cpp#L16
// Which would be saved to CCW cache.
// Creation of CCW is basically ManagedObjectWrapper::Create reimplementation.

// Maximum number of runtime supplied vtables.
Span<ComInterfaceEntry> runtimeDefinedLocal = stackalloc ComInterfaceEntry[4];
int runtimeDefinedCount = 0;

// Check if the caller will provide the IUnknown table.
if ((flags & CreateComInterfaceFlags.CallerDefinedIUnknown) == CreateComInterfaceFlags.None)
{
ComInterfaceEntry curr = runtimeDefinedLocal[runtimeDefinedCount++];
curr.IID = typeof(IUnknownVftbl).GUID;
curr.Vtable = ComWrappersSupport.IUnknownVftblPtr;
}

// Check if the caller wants tracker support.
// if ((flags & CreateComInterfaceFlags.TrackerSupport) == CreateComInterfaceFlags.TrackerSupport)
// {
// ComInterfaceEntry* curr = runtimeDefinedLocal[runtimeDefinedCount++];
// curr->IID = __uuidof(IReferenceTrackerTarget);
// curr->Vtable = &ManagedObjectWrapper_IReferenceTrackerTargetImpl;
// }

// Compute size for ManagedObjectWrapper instance.
nuint totalRuntimeDefinedSize = (nuint)runtimeDefinedCount * (nuint)sizeof(ComInterfaceEntry);
nuint totalDefinedCount = (nuint)runtimeDefinedCount + (nuint)userDefinedCount;

// Compute the total entry size of dispatch section.
//nuint totalDispatchSectionCount = ComputeThisPtrForDispatchSection(totalDefinedCount) + totalDefinedCount;
//nuint totalDispatchSectionSize = totalDispatchSectionCount * (nuint)sizeof(void*);

// Allocate memory for the ManagedObjectWrapper.
//char* wrapperMem = (char*)InteropLibImports::MemAlloc(sizeof(ManagedObjectWrapper) + totalRuntimeDefinedSize + totalDispatchSectionSize + ABI::AlignmentThisPtrMaxPadding, AllocScenario::ManagedObjectWrapper);

return null;
}

// Called by the runtime to execute the abstract instance function
internal static unsafe void* CallComputeVtables(ComWrappersScenario scenario, ComWrappers? comWrappersImpl, object obj, CreateComInterfaceFlags flags, out int count)
{
Expand Down Expand Up @@ -299,17 +290,15 @@ private static bool TryGetOrCreateObjectForComInstanceInternal(
if (externalComObject == IntPtr.Zero)
throw new ArgumentNullException(nameof(externalComObject));

if (flags.HasFlag(CreateObjectFlags.Aggregation))
throw new NotImplementedException();

// If the inner is supplied the Aggregation flag should be set.
if (innerMaybe != IntPtr.Zero && !flags.HasFlag(CreateObjectFlags.Aggregation))
throw new InvalidOperationException(SR.InvalidOperation_SuppliedInnerMustBeMarkedAggregation);
// if (innerMaybe != IntPtr.Zero && !flags.HasFlag(CreateObjectFlags.Aggregation))
// throw new InvalidOperationException(SR.InvalidOperation_SuppliedInnerMustBeMarkedAggregation);

object? wrapperMaybeLocal = wrapperMaybe;
retValue = null;
return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), impl.id, externalComObject, innerMaybe, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));
}

private static bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, long wrapperId, IntPtr externalComObject, IntPtr innerMaybe, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue)
{
throw new NotImplementedException();
}

Expand Down Expand Up @@ -369,19 +358,67 @@ public static void RegisterForMarshalling(ComWrappers instance)
/// <param name="fpQueryInterface">Function pointer to QueryInterface.</param>
/// <param name="fpAddRef">Function pointer to AddRef.</param>
/// <param name="fpRelease">Function pointer to Release.</param>
protected static void GetIUnknownImpl(out IntPtr fpQueryInterface, out IntPtr fpAddRef, out IntPtr fpRelease)
=> GetIUnknownImplInternal(out fpQueryInterface, out fpAddRef, out fpRelease);
protected internal static void GetIUnknownImpl(out IntPtr fpQueryInterface, out IntPtr fpAddRef, out IntPtr fpRelease)
=> ComWrappersSupport.GetIUnknownImplInternal(out fpQueryInterface, out fpAddRef, out fpRelease);

internal static int CallICustomQueryInterface(object customQueryInterfaceMaybe, ref Guid iid, out IntPtr ppObject)
{
var customQueryInterface = customQueryInterfaceMaybe as ICustomQueryInterface;
if (customQueryInterface is null)
{
ppObject = IntPtr.Zero;
return -1; // See TryInvokeICustomQueryInterfaceResult
}

return (int)customQueryInterface.GetInterface(ref iid, out ppObject);
}
}

private static void GetIUnknownImplInternal(out IntPtr fpQueryInterface, out IntPtr fpAddRef, out IntPtr fpRelease)
[Guid("00000000-0000-0000-C000-000000000046")]
[CLSCompliant(false)]
public unsafe struct IUnknownVftbl
{
private void* _QueryInterface;
public delegate* unmanaged[Stdcall]<IntPtr, ref Guid, out IntPtr, int> QueryInterface { get => (delegate* unmanaged[Stdcall]<IntPtr, ref Guid, out IntPtr, int>)_QueryInterface; set => _QueryInterface = (void*)value; }
private void* _AddRef;
public delegate* unmanaged[Stdcall]<IntPtr, uint> AddRef { get => (delegate* unmanaged[Stdcall]<IntPtr, uint>)_AddRef; set => _AddRef = (void*)value; }
private void* _Release;
public delegate* unmanaged[Stdcall]<IntPtr, uint> Release { get => (delegate* unmanaged[Stdcall]<IntPtr, uint>)_Release; set => _Release = (void*)value; }

public static IUnknownVftbl AbiToProjectionVftbl => ComWrappersSupport.IUnknownVftbl;
public static IntPtr AbiToProjectionVftblPtr => ComWrappersSupport.IUnknownVftblPtr;
}

internal class ComWrappersSupport
{
public static unsafe IUnknownVftbl IUnknownVftbl => Unsafe.AsRef<IUnknownVftbl>(IUnknownVftblPtr.ToPointer());

internal static IntPtr IUnknownVftblPtr { get; }

static unsafe ComWrappersSupport()
{
fpQueryInterface = (IntPtr)&ABI_QueryInterface;
fpAddRef = (IntPtr)&ABI_AddRef;
fpRelease = (IntPtr)&ABI_Release;
GetIUnknownImplInternal(out var qi, out var addRef, out var release);

IUnknownVftblPtr = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVftbl), sizeof(IUnknownVftbl));
(*(IUnknownVftbl*)IUnknownVftblPtr) = new IUnknownVftbl
{
QueryInterface = (delegate* unmanaged[Stdcall]<IntPtr, ref Guid, out IntPtr, int>)qi,
AddRef = (delegate* unmanaged[Stdcall]<IntPtr, uint>)addRef,
Release = (delegate* unmanaged[Stdcall]<IntPtr, uint>)release,
};
}

internal static unsafe void GetIUnknownImplInternal(out IntPtr fpQueryInterface, out IntPtr fpAddRef, out IntPtr fpRelease)
{
fpQueryInterface = (IntPtr)(delegate* unmanaged<IntPtr, Guid*, IntPtr*, int>)&ABI_QueryInterface;
fpAddRef = (IntPtr)(delegate* unmanaged<IntPtr, int>)&ABI_AddRef;
fpRelease = (IntPtr)(delegate* unmanaged<IntPtr, int>)&ABI_Release;
}

[UnmanagedCallersOnly]
private static int ABI_QueryInterface(IntPtr ppObject, ref Guid guid, out IntPtr returnValue)
private static unsafe int ABI_QueryInterface(IntPtr ppObject, Guid* guid, IntPtr* returnValue)
{
*returnValue = IntPtr.Zero;
return 0;
}

Expand All @@ -396,17 +433,5 @@ private static int ABI_Release(IntPtr ppObject)
{
return 0;
}

internal static int CallICustomQueryInterface(object customQueryInterfaceMaybe, ref Guid iid, out IntPtr ppObject)
{
var customQueryInterface = customQueryInterfaceMaybe as ICustomQueryInterface;
if (customQueryInterface is null)
{
ppObject = IntPtr.Zero;
return -1; // See TryInvokeICustomQueryInterfaceResult
}

return (int)customQueryInterface.GetInterface(ref iid, out ppObject);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public partial struct GCHandle

private static void InternalFree(IntPtr handle) => RuntimeImports.RhHandleFree(handle);

private static object InternalGet(IntPtr handle) => RuntimeImports.RhHandleGet(handle);
internal static object InternalGet(IntPtr handle) => RuntimeImports.RhHandleGet(handle);

private static void InternalSet(IntPtr handle, object value) => RuntimeImports.RhHandleSet(handle, value);
}
Expand Down

0 comments on commit 31f8365

Please sign in to comment.