From 13206adaf03044ede57c08a1f218c3ec191f9622 Mon Sep 17 00:00:00 2001 From: Manodasan Wignarajah Date: Wed, 15 Mar 2023 02:09:48 -0700 Subject: [PATCH 1/2] Fix issue with weak reference when marshaled across contexts where we can end up treating IUnknown as the IWeakReference interface --- .../UnitTest/TestComponentCSharp_Tests.cs | 44 ++++++++++++++++++- src/WinRT.Runtime/ComWrappersSupport.cs | 25 +++++++++++ src/WinRT.Runtime/ComWrappersSupport.net5.cs | 2 +- 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs index 5bcfc8936..a16db983d 100644 --- a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs +++ b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs @@ -9,7 +9,6 @@ using Windows.Foundation; using Windows.UI; -using Windows.Security.Credentials.UI; using Windows.Storage; using Windows.Storage.Streams; using Microsoft.UI.Xaml; @@ -30,6 +29,7 @@ using System.Reflection; using Windows.Devices.Enumeration.Pnp; using System.Diagnostics; +using Windows.Devices.Enumeration; #if NET using WeakRefNS = System; @@ -2936,5 +2936,47 @@ private void TestExperimentAttribute() CustomExperimentClass custom = new CustomExperimentClass(); custom.f(); } + + void OnDeviceAdded(DeviceWatcher sender, DeviceInformation args) + { + } + + void OnDeviceUpdated(DeviceWatcher sender, DeviceInformationUpdate args) + { + } + + [Fact] + public void TestWeakReferenceEventsFromMultipleContexts() + { + SemaphoreSlim semaphore = new SemaphoreSlim(0); + DeviceWatcher watcher = null; + + Thread staThread = new Thread(() => + { + Assert.True(Thread.CurrentThread.GetApartmentState() == ApartmentState.STA); + + watcher = DeviceInformation.CreateWatcher(); + var exception = Record.Exception(() => { + watcher.Added += OnDeviceAdded; + }); + Assert.Null(exception); + + Thread mtaThread = new Thread(() => + { + Assert.True(Thread.CurrentThread.GetApartmentState() == ApartmentState.MTA); + + exception = Record.Exception(() => { + watcher.Updated += OnDeviceUpdated; + }); + Assert.Null(exception); + }); + mtaThread.SetApartmentState(ApartmentState.MTA); + mtaThread.Start(); + mtaThread.Join(); + }); + staThread.SetApartmentState(ApartmentState.STA); + staThread.Start(); + staThread.Join(); + } } } diff --git a/src/WinRT.Runtime/ComWrappersSupport.cs b/src/WinRT.Runtime/ComWrappersSupport.cs index 5ae14f35e..a955bd6e6 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.cs @@ -150,6 +150,31 @@ public static ObjectReference GetObjectReferenceForInterface(IntPtr extern } } + internal static ObjectReference GetObjectReferenceForInterfaceWithKnownIID(IntPtr externalComObject, Guid iid) + { + if (externalComObject == IntPtr.Zero) + { + return null; + } + + ObjectReference objRef = ObjectReference.FromAbi(externalComObject); + if (IsFreeThreaded(objRef)) + { + return objRef; + } + else + { + using (objRef) + { + return new ObjectReferenceWithContext( + objRef.GetRef(), + Context.GetContextCallback(), + Context.GetContextToken(), + iid); + } + } + } + public static void RegisterProjectionAssembly(Assembly assembly) => TypeNameSupport.RegisterProjectionAssembly(assembly); public static void RegisterProjectionTypeBaseTypeMapping(IDictionary typeNameToBaseTypeNameMapping) => TypeNameSupport.RegisterProjectionTypeBaseTypeMapping(typeNameToBaseTypeNameMapping); diff --git a/src/WinRT.Runtime/ComWrappersSupport.net5.cs b/src/WinRT.Runtime/ComWrappersSupport.net5.cs index 24a0a6dc4..d1c311b2d 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.net5.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.net5.cs @@ -539,7 +539,7 @@ private static object CreateObject(IntPtr externalComObject) { // IWeakReference is IUnknown-based, so implementations of it may not (and likely won't) implement // IInspectable. As a result, we need to check for them explicitly. - var iunknownObjRef = ComWrappersSupport.GetObjectReferenceForInterface(ptr); + var iunknownObjRef = ComWrappersSupport.GetObjectReferenceForInterfaceWithKnownIID(ptr, weakReferenceIID); ComWrappersHelper.Init(iunknownObjRef); return new SingleInterfaceOptimizedObject(typeof(IWeakReference), iunknownObjRef, false); From ca9db71cf64986985905ee1e905863e96775179f Mon Sep 17 00:00:00 2001 From: Manodasan Wignarajah Date: Wed, 15 Mar 2023 12:11:02 -0700 Subject: [PATCH 2/2] Slight refactor --- src/WinRT.Runtime/ComWrappersSupport.cs | 31 +++++++------------- src/WinRT.Runtime/ComWrappersSupport.net5.cs | 2 +- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/src/WinRT.Runtime/ComWrappersSupport.cs b/src/WinRT.Runtime/ComWrappersSupport.cs index a955bd6e6..dd7d25bb4 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.cs @@ -125,39 +125,28 @@ public static ObjectReference GetObjectReferenceForInterface(IntPtr extern } public static ObjectReference GetObjectReferenceForInterface(IntPtr externalComObject, Guid iid) + { + return GetObjectReferenceForInterface(externalComObject, iid, true); + } + + internal static ObjectReference GetObjectReferenceForInterface(IntPtr externalComObject, Guid iid, bool requireQI) { if (externalComObject == IntPtr.Zero) { return null; } - Marshal.ThrowExceptionForHR(Marshal.QueryInterface(externalComObject, ref iid, out IntPtr ptr)); - ObjectReference objRef = ObjectReference.Attach(ref ptr); - if (IsFreeThreaded(objRef)) + ObjectReference objRef; + if (requireQI) { - return objRef; + Marshal.ThrowExceptionForHR(Marshal.QueryInterface(externalComObject, ref iid, out IntPtr ptr)); + objRef = ObjectReference.Attach(ref ptr); } else { - using (objRef) - { - return new ObjectReferenceWithContext( - objRef.GetRef(), - Context.GetContextCallback(), - Context.GetContextToken(), - iid); - } - } - } - - internal static ObjectReference GetObjectReferenceForInterfaceWithKnownIID(IntPtr externalComObject, Guid iid) - { - if (externalComObject == IntPtr.Zero) - { - return null; + objRef = ObjectReference.FromAbi(externalComObject); } - ObjectReference objRef = ObjectReference.FromAbi(externalComObject); if (IsFreeThreaded(objRef)) { return objRef; diff --git a/src/WinRT.Runtime/ComWrappersSupport.net5.cs b/src/WinRT.Runtime/ComWrappersSupport.net5.cs index d1c311b2d..a178e72cd 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.net5.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.net5.cs @@ -539,7 +539,7 @@ private static object CreateObject(IntPtr externalComObject) { // IWeakReference is IUnknown-based, so implementations of it may not (and likely won't) implement // IInspectable. As a result, we need to check for them explicitly. - var iunknownObjRef = ComWrappersSupport.GetObjectReferenceForInterfaceWithKnownIID(ptr, weakReferenceIID); + var iunknownObjRef = ComWrappersSupport.GetObjectReferenceForInterface(ptr, weakReferenceIID, false); ComWrappersHelper.Init(iunknownObjRef); return new SingleInterfaceOptimizedObject(typeof(IWeakReference), iunknownObjRef, false);