From 5c60130d029f7583ce4189a6801554f1a205e8fa Mon Sep 17 00:00:00 2001 From: Manodasan Wignarajah Date: Wed, 15 Mar 2023 17:14:31 -0700 Subject: [PATCH] Fix resolving IWeakReference from different context (#1301) * Fix issue with weak reference when marshaled across contexts where we can end up treating IUnknown as the IWeakReference interface * Slight refactor --- .../UnitTest/TestComponentCSharp_Tests.cs | 44 ++++++++++++++++++- src/WinRT.Runtime/ComWrappersSupport.cs | 18 +++++++- src/WinRT.Runtime/ComWrappersSupport.net5.cs | 2 +- 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs index 5bcfc8936..a16db983d 100644 --- a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs +++ b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs @@ -9,7 +9,6 @@ using Windows.Foundation; using Windows.UI; -using Windows.Security.Credentials.UI; using Windows.Storage; using Windows.Storage.Streams; using Microsoft.UI.Xaml; @@ -30,6 +29,7 @@ using System.Reflection; using Windows.Devices.Enumeration.Pnp; using System.Diagnostics; +using Windows.Devices.Enumeration; #if NET using WeakRefNS = System; @@ -2936,5 +2936,47 @@ private void TestExperimentAttribute() CustomExperimentClass custom = new CustomExperimentClass(); custom.f(); } + + void OnDeviceAdded(DeviceWatcher sender, DeviceInformation args) + { + } + + void OnDeviceUpdated(DeviceWatcher sender, DeviceInformationUpdate args) + { + } + + [Fact] + public void TestWeakReferenceEventsFromMultipleContexts() + { + SemaphoreSlim semaphore = new SemaphoreSlim(0); + DeviceWatcher watcher = null; + + Thread staThread = new Thread(() => + { + Assert.True(Thread.CurrentThread.GetApartmentState() == ApartmentState.STA); + + watcher = DeviceInformation.CreateWatcher(); + var exception = Record.Exception(() => { + watcher.Added += OnDeviceAdded; + }); + Assert.Null(exception); + + Thread mtaThread = new Thread(() => + { + Assert.True(Thread.CurrentThread.GetApartmentState() == ApartmentState.MTA); + + exception = Record.Exception(() => { + watcher.Updated += OnDeviceUpdated; + }); + Assert.Null(exception); + }); + mtaThread.SetApartmentState(ApartmentState.MTA); + mtaThread.Start(); + mtaThread.Join(); + }); + staThread.SetApartmentState(ApartmentState.STA); + staThread.Start(); + staThread.Join(); + } } } diff --git a/src/WinRT.Runtime/ComWrappersSupport.cs b/src/WinRT.Runtime/ComWrappersSupport.cs index 726605e34..fc2f6bf33 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.cs @@ -125,14 +125,28 @@ public static ObjectReference GetObjectReferenceForInterface(IntPtr extern } public static ObjectReference GetObjectReferenceForInterface(IntPtr externalComObject, Guid iid) + { + return GetObjectReferenceForInterface(externalComObject, iid, true); + } + + internal static ObjectReference GetObjectReferenceForInterface(IntPtr externalComObject, Guid iid, bool requireQI) { if (externalComObject == IntPtr.Zero) { return null; } - Marshal.ThrowExceptionForHR(Marshal.QueryInterface(externalComObject, ref iid, out IntPtr ptr)); - ObjectReference objRef = ObjectReference.Attach(ref ptr); + ObjectReference objRef; + if (requireQI) + { + Marshal.ThrowExceptionForHR(Marshal.QueryInterface(externalComObject, ref iid, out IntPtr ptr)); + objRef = ObjectReference.Attach(ref ptr); + } + else + { + objRef = ObjectReference.FromAbi(externalComObject); + } + if (IsFreeThreaded(objRef)) { return objRef; diff --git a/src/WinRT.Runtime/ComWrappersSupport.net5.cs b/src/WinRT.Runtime/ComWrappersSupport.net5.cs index 24a0a6dc4..a178e72cd 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.net5.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.net5.cs @@ -539,7 +539,7 @@ private static object CreateObject(IntPtr externalComObject) { // IWeakReference is IUnknown-based, so implementations of it may not (and likely won't) implement // IInspectable. As a result, we need to check for them explicitly. - var iunknownObjRef = ComWrappersSupport.GetObjectReferenceForInterface(ptr); + var iunknownObjRef = ComWrappersSupport.GetObjectReferenceForInterface(ptr, weakReferenceIID, false); ComWrappersHelper.Init(iunknownObjRef); return new SingleInterfaceOptimizedObject(typeof(IWeakReference), iunknownObjRef, false);