diff --git a/src/WinRT.Runtime/ComWrappersSupport.netstandard2.0.cs b/src/WinRT.Runtime/ComWrappersSupport.netstandard2.0.cs index 6377a0e0c..92ec5f02f 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.netstandard2.0.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.netstandard2.0.cs @@ -25,9 +25,14 @@ static partial class ComWrappersSupport private static ConcurrentDictionary> RuntimeWrapperCache = new ConcurrentDictionary>(); private readonly static ConcurrentDictionary> TypeObjectRefFuncCache = new ConcurrentDictionary>(); - internal static InspectableInfo GetInspectableInfo(IntPtr pThis) => UnmanagedObject.FindObject(pThis).InspectableInfo; + internal static InspectableInfo GetInspectableInfo(IntPtr pThis) => UnmanagedObject.FindObject(pThis).InspectableInfo; + + public static T CreateRcwForComObject(IntPtr ptr) + { + return CreateRcwForComObject(ptr, true); + } - public static T CreateRcwForComObject(IntPtr ptr) + private static T CreateRcwForComObject(IntPtr ptr, bool tryUseCache) { if (ptr == IntPtr.Zero) { @@ -47,9 +52,17 @@ public static T CreateRcwForComObject(IntPtr ptr) } else if (identity.TryAs(out var inspectableRef) == 0) { - var inspectable = new IInspectable(identity); - Type runtimeClassType = GetRuntimeClassForTypeCreation(inspectable, typeof(T)); - runtimeWrapper = runtimeClassType == null ? inspectable : TypedObjectFactoryCacheForType.GetOrAdd(runtimeClassType, classType => CreateTypedRcwFactory(classType))(inspectable); + var inspectable = new IInspectable(identity); + + if (typeof(T).IsSealed) + { + runtimeWrapper = TypedObjectFactoryCacheForType.GetOrAdd(typeof(T), classType => CreateTypedRcwFactory(classType))(inspectable); + } + else + { + Type runtimeClassType = GetRuntimeClassForTypeCreation(inspectable, typeof(T)); + runtimeWrapper = runtimeClassType == null ? inspectable : TypedObjectFactoryCacheForType.GetOrAdd(runtimeClassType, classType => CreateTypedRcwFactory(classType))(inspectable); + } } else if (identity.TryAs(out var weakRef) == 0) { @@ -62,17 +75,25 @@ public static T CreateRcwForComObject(IntPtr ptr) return runtimeWrapperReference; }; - RuntimeWrapperCache.AddOrUpdate( - identity.ThisPtr, - rcwFactory, - (ptr, oldValue) => - { - if (!oldValue.TryGetTarget(out keepAliveSentinel)) - { - return rcwFactory(ptr); - } - return oldValue; - }).TryGetTarget(out object rcw); + object rcw; + if (tryUseCache) + { + RuntimeWrapperCache.AddOrUpdate( + identity.ThisPtr, + rcwFactory, + (ptr, oldValue) => + { + if (!oldValue.TryGetTarget(out keepAliveSentinel)) + { + return rcwFactory(ptr); + } + return oldValue; + }).TryGetTarget(out rcw); + } + else + { + rcwFactory(ptr).TryGetTarget(out rcw); + } GC.KeepAlive(keepAliveSentinel); @@ -86,7 +107,9 @@ public static T CreateRcwForComObject(IntPtr ptr) return rcw switch { ABI.System.Nullable nt => (T)nt.Value, - _ => (T)rcw + T castRcw => castRcw, + _ when tryUseCache => CreateRcwForComObject(ptr, false), + _ => throw new ArgumentException(string.Format("Unable to create a wrapper object. The WinRT object {0} has type {1} which cannot be assigned to type {2}", ptr, rcw.GetType(), typeof(T))) }; }