Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf improvements in CCW creation. #739

Merged
merged 3 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions src/WinRT.Runtime/ComWrappersSupport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ internal static object GetRuntimeClassCCWTypeIfAny(object obj)
return obj;
}

internal static List<ComInterfaceEntry> GetInterfaceTableEntries(object obj)
internal static List<ComInterfaceEntry> GetInterfaceTableEntries(Type type)
{
var entries = new List<ComInterfaceEntry>();
var objType = obj.GetType().GetRuntimeClassCCWType() ?? obj.GetType();
var objType = type.GetRuntimeClassCCWType() ?? type;
var interfaces = objType.GetInterfaces();
foreach (var iface in interfaces)
{
Expand Down Expand Up @@ -122,12 +122,12 @@ internal static List<ComInterfaceEntry> GetInterfaceTableEntries(object obj)
}
}

if (obj is Delegate)
if (type.IsDelegate())
{
entries.Add(new ComInterfaceEntry
{
IID = GuidGenerator.GetIID(obj.GetType()),
Vtable = (IntPtr)obj.GetType().GetHelperType().GetAbiToProjectionVftblPtr()
IID = GuidGenerator.GetIID(type),
Vtable = (IntPtr)type.GetHelperType().GetAbiToProjectionVftblPtr()
});
}

Expand All @@ -140,15 +140,15 @@ internal static List<ComInterfaceEntry> GetInterfaceTableEntries(object obj)
Vtable = (IntPtr)ifaceAbiType.GetAbiToProjectionVftblPtr()
});
}
else if (ShouldProvideIReference(obj))
else if (ShouldProvideIReference(type))
{
entries.Add(IPropertyValueEntry);
entries.Add(ProvideIReference(obj));
entries.Add(ProvideIReference(type));
}
else if (ShouldProvideIReferenceArray(obj))
else if (ShouldProvideIReferenceArray(type))
{
entries.Add(IPropertyValueEntry);
entries.Add(ProvideIReferenceArray(obj));
entries.Add(ProvideIReferenceArray(type));
}

entries.Add(new ComInterfaceEntry
Expand Down Expand Up @@ -178,17 +178,15 @@ internal static List<ComInterfaceEntry> GetInterfaceTableEntries(object obj)
return entries;
}

internal static (InspectableInfo inspectableInfo, List<ComInterfaceEntry> interfaceTableEntries) PregenerateNativeTypeInformation(object obj)
{
var interfaceTableEntries = GetInterfaceTableEntries(obj);
internal static (InspectableInfo inspectableInfo, List<ComInterfaceEntry> interfaceTableEntries) PregenerateNativeTypeInformation(Type type)
{
var interfaceTableEntries = GetInterfaceTableEntries(type);
var iids = new Guid[interfaceTableEntries.Count];
for (int i = 0; i < interfaceTableEntries.Count; i++)
{
iids[i] = interfaceTableEntries[i].IID;
}

Type type = obj.GetType();

if (type.FullName.StartsWith("ABI."))
{
type = Projections.FindCustomPublicTypeForAbiType(type) ?? type.Assembly.GetType(type.FullName.Substring("ABI.".Length)) ?? type;
Expand Down Expand Up @@ -337,23 +335,20 @@ internal static string GetRuntimeClassForTypeCreation(IInspectable inspectable,
return runtimeClassName;
}

private static bool ShouldProvideIReference(object obj)
private static bool ShouldProvideIReference(Type type)
{
return obj.GetType().IsValueType || obj is string || obj is Type || obj is Delegate;
return type.IsValueType || type == typeof(string) || type == typeof(Type) || type.IsDelegate();
}


private static ComInterfaceEntry IPropertyValueEntry =>
new ComInterfaceEntry
{
IID = global::WinRT.GuidGenerator.GetIID(typeof(global::Windows.Foundation.IPropertyValue)),
Vtable = ManagedIPropertyValueImpl.AbiToProjectionVftablePtr
};

private static ComInterfaceEntry ProvideIReference(object obj)
private static ComInterfaceEntry ProvideIReference(Type type)
{
Type type = obj.GetType();

if (type == typeof(int))
{
return new ComInterfaceEntry
Expand Down Expand Up @@ -482,7 +477,7 @@ private static ComInterfaceEntry ProvideIReference(object obj)
Vtable = BoxedValueIReferenceImpl<object>.AbiToProjectionVftablePtr
};
}
if (obj is Type)
if (type == typeof(Type))
{
return new ComInterfaceEntry
{
Expand All @@ -498,14 +493,15 @@ private static ComInterfaceEntry ProvideIReference(object obj)
};
}

private static bool ShouldProvideIReferenceArray(object obj)
private static bool ShouldProvideIReferenceArray(Type type)
{
return obj is Array arr && arr.Rank == 1 && arr.GetLowerBound(0) == 0 && !obj.GetType().GetElementType().IsArray;
// Check if one dimensional array with lower bound of 0
return type.IsArray && type == type.GetElementType().MakeArrayType() && !type.GetElementType().IsArray;
}

private static ComInterfaceEntry ProvideIReferenceArray(object obj)
private static ComInterfaceEntry ProvideIReferenceArray(Type arrayType)
{
Type type = obj.GetType().GetElementType();
Type type = arrayType.GetElementType();
if (type == typeof(int))
{
return new ComInterfaceEntry
Expand Down Expand Up @@ -634,7 +630,7 @@ private static ComInterfaceEntry ProvideIReferenceArray(object obj)
Vtable = BoxedArrayIReferenceArrayImpl<object>.AbiToProjectionVftablePtr
};
}
if (obj is Type)
if (type == typeof(Type))
{
return new ComInterfaceEntry
{
Expand Down
108 changes: 58 additions & 50 deletions src/WinRT.Runtime/ComWrappersSupport.net5.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
Expand All @@ -26,7 +27,7 @@ private static DefaultComWrappers DefaultComWrappersInstance
}
}

internal static readonly ConditionalWeakTable<object, InspectableInfo> InspectableInfoTable = new ConditionalWeakTable<object, InspectableInfo>();
internal static readonly ConditionalWeakTable<Type, InspectableInfo> InspectableInfoTable = new ConditionalWeakTable<Type, InspectableInfo>();
internal static readonly ThreadLocal<Type> CreateRCWType = new ThreadLocal<Type>();

private static ComWrappers _comWrappers;
Expand Down Expand Up @@ -67,7 +68,7 @@ private static ComWrappers ComWrappers
internal static unsafe InspectableInfo GetInspectableInfo(IntPtr pThis)
{
var _this = FindObject<object>(pThis);
return InspectableInfoTable.GetValue(_this, o => PregenerateNativeTypeInformation(o).inspectableInfo);
return InspectableInfoTable.GetValue(_this.GetType(), o => PregenerateNativeTypeInformation(o).inspectableInfo);
}

public static T CreateRcwForComObject<T>(IntPtr ptr)
Expand Down Expand Up @@ -193,7 +194,7 @@ private static Func<IInspectable, object> CreateFactoryForImplementationType(str

public class DefaultComWrappers : ComWrappers
{
private static ConditionalWeakTable<object, VtableEntriesCleanupScout> ComInterfaceEntryCleanupTable = new ConditionalWeakTable<object, VtableEntriesCleanupScout>();
private static readonly ConditionalWeakTable<Type, VtableEntries> TypeVtableEntryTable = new ConditionalWeakTable<Type, VtableEntries>();
public static unsafe IUnknownVftbl IUnknownVftbl => Unsafe.AsRef<IUnknownVftbl>(IUnknownVftblPtr.ToPointer());

internal static IntPtr IUnknownVftblPtr { get; }
Expand All @@ -212,51 +213,51 @@ static unsafe DefaultComWrappers()
}

protected override unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
if (IsRuntimeImplementedRCW(obj))
{
// If the object is a runtime-implemented RCW, let the runtime create a CCW.
count = 0;
return null;
}

var entries = ComWrappersSupport.GetInterfaceTableEntries(obj);

if (flags.HasFlag(CreateComInterfaceFlags.CallerDefinedIUnknown))
{
entries.Add(new ComInterfaceEntry
{
IID = typeof(IUnknownVftbl).GUID,
Vtable = IUnknownVftbl.AbiToProjectionVftblPtr
});
}

entries.Add(new ComInterfaceEntry
{
IID = typeof(IInspectable).GUID,
Vtable = IInspectable.Vftbl.AbiToProjectionVftablePtr
});

count = entries.Count;
ComInterfaceEntry* nativeEntries = (ComInterfaceEntry*)Marshal.AllocCoTaskMem(sizeof(ComInterfaceEntry) * count);

for (int i = 0; i < count; i++)
{
nativeEntries[i] = entries[i];
{
var vtableEntries = TypeVtableEntryTable.GetValue(obj.GetType(), (type) =>
{
if (IsRuntimeImplementedRCW(type))
{
// If the object is a runtime-implemented RCW, let the runtime create a CCW.
return new VtableEntries();
}

var entries = ComWrappersSupport.GetInterfaceTableEntries(type);

entries.Add(new ComInterfaceEntry
{
IID = typeof(IInspectable).GUID,
Vtable = IInspectable.Vftbl.AbiToProjectionVftablePtr
});

// This should be the last entry as it is included / excluded based on the flags.
entries.Add(new ComInterfaceEntry
{
IID = typeof(IUnknownVftbl).GUID,
Vtable = IUnknownVftbl.AbiToProjectionVftblPtr
});

return new VtableEntries(entries, type);
});

count = vtableEntries.Count;
if (count != 0 && !flags.HasFlag(CreateComInterfaceFlags.CallerDefinedIUnknown))
{
// The vtable list unconditionally has the last entry as IUnknown, but it should
// only be included if the flag is set. We achieve that by excluding the last entry
// from the count if the flag isn't set.
count -= 1;
}

ComInterfaceEntryCleanupTable.Add(obj, new VtableEntriesCleanupScout(nativeEntries));

return nativeEntries;
return vtableEntries.Data;
}

private static unsafe bool IsRuntimeImplementedRCW(object obj)
private static unsafe bool IsRuntimeImplementedRCW(Type objType)
{
Type t = obj.GetType();
bool isRcw = t.IsCOMObject;
if (t.IsGenericType)
bool isRcw = objType.IsCOMObject;
if (objType.IsGenericType)
{
foreach (var arg in t.GetGenericArguments())
foreach (var arg in objType.GetGenericArguments())
{
if (arg.IsCOMObject)
{
Expand Down Expand Up @@ -313,18 +314,25 @@ protected override void ReleaseObjects(IEnumerable objects)
}
}

unsafe class VtableEntriesCleanupScout
unsafe class VtableEntries
{
private readonly ComInterfaceEntry* _data;
public ComInterfaceEntry* Data { get; }
public int Count { get; }

public VtableEntriesCleanupScout(ComInterfaceEntry* data)
{
_data = data;
public VtableEntries()
{
Data = null;
Count = 0;
}

~VtableEntriesCleanupScout()
{
Marshal.FreeCoTaskMem((IntPtr)_data);
public VtableEntries(List<ComInterfaceEntry> entries, Type type)
{
Data = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(type, sizeof(ComInterfaceEntry) * entries.Count);
for (int i = 0; i < entries.Count; i++)
{
Data[i] = entries[i];
}
Count = entries.Count;
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/WinRT.Runtime/ComWrappersSupport.netstandard2.0.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ internal static T FindObject<T>(IntPtr thisPtr)
GCHandle thisHandle = GCHandle.FromIntPtr(unmanagedObject._gchandlePtr);
return (T)thisHandle.Target;
}
}
}

internal class ComCallableWrapper
{
{
private Dictionary<Guid, IntPtr> _managedQITable;
private GCHandle _qiTableHandle;
private volatile IntPtr _strongHandle;
Expand All @@ -304,7 +304,7 @@ public ComCallableWrapper(object obj)
_strongHandle = IntPtr.Zero;
WeakHandle = GCHandle.Alloc(this, GCHandleType.WeakTrackResurrection);
ManagedObject = obj;
var (inspectableInfo, interfaceTableEntries) = ComWrappersSupport.PregenerateNativeTypeInformation(ManagedObject);
var (inspectableInfo, interfaceTableEntries) = ComWrappersSupport.PregenerateNativeTypeInformation(ManagedObject.GetType());

InspectableInfo = inspectableInfo;

Expand Down