Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose api's for Context on ObjectReference. #1391

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 21 additions & 38 deletions src/WinRT.Runtime/ComWrappersSupport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,15 @@ public static void MarshalDelegateInvoke<T>(IntPtr thisPtr, Action<T> invoke)

// If we are free threaded, we do not need to keep track of context.
// This can either be if the object implements IAgileObject or the free threaded marshaler.
internal unsafe static bool IsFreeThreaded(IObjectReference objRef)
internal unsafe static bool IsFreeThreaded(IntPtr iUnknown)
{
if (objRef.TryAs(InterfaceIIDs.IAgileObject_IID, out var agilePtr) >= 0)
if (Marshal.QueryInterface(iUnknown, ref Unsafe.AsRef(InterfaceIIDs.IAgileObject_IID), out var agilePtr) >= 0)
{
Marshal.Release(agilePtr);
return true;
}
else if (objRef.TryAs(InterfaceIIDs.IMarshal_IID, out var marshalPtr) >= 0)

if (Marshal.QueryInterface(iUnknown, ref Unsafe.AsRef(InterfaceIIDs.IMarshal_IID), out var marshalPtr) >= 0)
{
try
{
Expand All @@ -103,6 +104,14 @@ internal unsafe static bool IsFreeThreaded(IObjectReference objRef)
return false;
}

internal unsafe static bool IsFreeThreaded(IObjectReference objRef)
{
var isFreeThreaded = IsFreeThreaded(objRef.ThisPtr);
// ThisPtr is owned by objRef, so need to make sure objRef stays alive.
GC.KeepAlive(objRef);
return isFreeThreaded;
}

public static IObjectReference GetObjectReferenceForInterface(IntPtr externalComObject)
{
return GetObjectReferenceForInterface<IUnknownVftbl>(externalComObject);
Expand All @@ -115,21 +124,7 @@ public static ObjectReference<T> GetObjectReferenceForInterface<T>(IntPtr extern
return null;
}

ObjectReference<T> objRef = ObjectReference<T>.FromAbi(externalComObject);
if (IsFreeThreaded(objRef))
{
return objRef;
}
else
{
using (objRef)
{
return new ObjectReferenceWithContext<T>(
objRef.GetRef(),
Context.GetContextCallback(),
Context.GetContextToken());
}
}
return ObjectReference<T>.FromAbi(externalComObject);
}

public static ObjectReference<T> GetObjectReferenceForInterface<T>(IntPtr externalComObject, Guid iid)
Expand All @@ -144,31 +139,14 @@ internal static ObjectReference<T> GetObjectReferenceForInterface<T>(IntPtr exte
return null;
}

ObjectReference<T> objRef;
if (requireQI)
{
Marshal.ThrowExceptionForHR(Marshal.QueryInterface(externalComObject, ref iid, out IntPtr ptr));
objRef = ObjectReference<T>.Attach(ref ptr);
return ObjectReference<T>.Attach(ref ptr);
}
else
{
objRef = ObjectReference<T>.FromAbi(externalComObject);
}

if (IsFreeThreaded(objRef))
{
return objRef;
}
else
{
using (objRef)
{
return new ObjectReferenceWithContext<T>(
objRef.GetRef(),
Context.GetContextCallback(),
Context.GetContextToken(),
iid);
}
return ObjectReference<T>.FromAbi(externalComObject);
}
}

Expand Down Expand Up @@ -487,7 +465,12 @@ private static Func<IInspectable, object> CreateCustomTypeMappingFactory(Type cu
}

var fromAbiMethodFunc = (Func<IntPtr, object>) fromAbiMethod.CreateDelegate(typeof(Func<IntPtr, object>));
return (IInspectable obj) => fromAbiMethodFunc(obj.ThisPtr);
return (IInspectable obj) =>
{
var fromAbiMethod = fromAbiMethodFunc(obj.ThisPtr);
GC.KeepAlive(obj);
return fromAbiMethod;
};
}

internal static Func<IInspectable, object> CreateTypedRcwFactory(
Expand Down
8 changes: 3 additions & 5 deletions src/WinRT.Runtime/ComWrappersSupport.net5.cs
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,7 @@ public unsafe static void Init(
// otherwise the new instance will be used. Since the inner was composed
// it should answer immediately without going through the outer. Either way
// the reference count will go to the new instance.
Guid iid = IReferenceTrackerVftbl.IID;
int hr = Marshal.QueryInterface(objRef.ThisPtr, ref iid, out referenceTracker);
int hr = Marshal.QueryInterface(objRef.ThisPtr, ref Unsafe.AsRef(IReferenceTrackerVftbl.IID), out referenceTracker);
if (hr != 0)
{
referenceTracker = default;
Expand Down Expand Up @@ -450,9 +449,8 @@ public unsafe static void Init(
public unsafe static void Init(IObjectReference objRef, bool addRefFromTrackerSource = true)
{
if (objRef.ReferenceTrackerPtr == IntPtr.Zero)
{
Guid iid = IReferenceTrackerVftbl.IID;
int hr = Marshal.QueryInterface(objRef.ThisPtr, ref iid, out var referenceTracker);
{
int hr = Marshal.QueryInterface(objRef.ThisPtr, ref Unsafe.AsRef(IReferenceTrackerVftbl.IID), out var referenceTracker);
if (hr == 0)
{
// WinUI scenario
Expand Down
1 change: 1 addition & 0 deletions src/WinRT.Runtime/ExceptionHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ public static void ReportUnhandledError(Exception ex)
if (restrictedErrorInfoRef != null)
{
roReportUnhandledError(restrictedErrorInfoRef.ThisPtr);
GC.KeepAlive(restrictedErrorInfoRef);
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/WinRT.Runtime/Marshalers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1581,8 +1581,7 @@ public static T FromAbi(IntPtr ptr)
IntPtr iunknownPtr = IntPtr.Zero;
try
{
Guid iid_iunknown = IUnknownVftbl.IID;
Marshal.QueryInterface(ptr, ref iid_iunknown, out iunknownPtr);
Marshal.QueryInterface(ptr, ref Unsafe.AsRef(IUnknownVftbl.IID), out iunknownPtr);
if (IUnknownVftbl.IsReferenceToManagedObject(iunknownPtr))
{
return (T)ComWrappersSupport.FindObject<object>(iunknownPtr);
Expand Down
4 changes: 3 additions & 1 deletion src/WinRT.Runtime/MatchingRefApiCompatBaseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,6 @@ MembersMustExist : Member 'public System.String WinRT.WindowsRuntimeTypeAttribut
TypesMustExist : Type 'WinRT.WinRTExposedTypeAttribute' does not exist in the reference but it does exist in the implementation.
MembersMustExist : Member 'public T ABI.System.Nullable<T>.GetValue(WinRT.IInspectable)' does not exist in the reference but it does exist in the implementation.
TypesMustExist : Type 'WinRT.EventRegistrationTokenTable<T>' does not exist in the reference but it does exist in the implementation.
Total Issues: 132
MembersMustExist : Member 'public System.Boolean WinRT.IObjectReference.IsFreeThreaded.get()' does not exist in the reference but it does exist in the implementation.
MembersMustExist : Member 'public System.Boolean WinRT.IObjectReference.IsInCurrentContext.get()' does not exist in the reference but it does exist in the implementation.
Total Issues: 134
86 changes: 66 additions & 20 deletions src/WinRT.Runtime/ObjectReference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ public IntPtr ThisPtr
}
}

public bool IsFreeThreaded => GetContextToken() == IntPtr.Zero;

public bool IsInCurrentContext
{
get
{
var contextToken = GetContextToken();
return contextToken == IntPtr.Zero || contextToken == Context.GetContextToken();
}
}

private protected IntPtr ThisPtrFromOriginalContext
{
get
Expand Down Expand Up @@ -306,7 +317,7 @@ internal bool Resurrect()
protected virtual unsafe void AddRef(bool refFromTrackerSource)
{
Marshal.AddRef(ThisPtr);
if(refFromTrackerSource)
if (refFromTrackerSource)
{
AddRefFromTrackerSource();
}
Expand Down Expand Up @@ -382,6 +393,11 @@ private protected virtual IntPtr GetThisPtrForCurrentContext()
return ThisPtrFromOriginalContext;
}

private protected virtual IntPtr GetContextToken()
{
return IntPtr.Zero;
}

public ObjectReferenceValue AsValue()
{
// Sharing ptr with objref.
Expand Down Expand Up @@ -423,18 +439,7 @@ public T Vftbl
}
}

public static ObjectReference<T> Attach(ref IntPtr thisPtr)
{
if (thisPtr == IntPtr.Zero)
{
return null;
}
var obj = new ObjectReference<T>(thisPtr);
thisPtr = IntPtr.Zero;
return obj;
}

ObjectReference(IntPtr thisPtr, T vftblT) :
private protected ObjectReference(IntPtr thisPtr, T vftblT) :
base(thisPtr)
{
_vftbl = vftblT;
Expand All @@ -445,15 +450,51 @@ private protected ObjectReference(IntPtr thisPtr) :
{
}

public static unsafe ObjectReference<T> FromAbi(IntPtr thisPtr, T vftblT)
public static ObjectReference<T> Attach(ref IntPtr thisPtr)
{
if (thisPtr == IntPtr.Zero)
{
return null;
}

if (ComWrappersSupport.IsFreeThreaded(thisPtr))
{
var obj = new ObjectReference<T>(thisPtr);
thisPtr = IntPtr.Zero;
return obj;
}
else
{
var obj = new ObjectReferenceWithContext<T>(
thisPtr,
Context.GetContextCallback(),
Context.GetContextToken());
thisPtr = IntPtr.Zero;
return obj;
}
}

public static unsafe ObjectReference<T> FromAbi(IntPtr thisPtr, T vftblT)
{
if (thisPtr == IntPtr.Zero)
{
return null;
}

Marshal.AddRef(thisPtr);
var obj = new ObjectReference<T>(thisPtr, vftblT);
return obj;
if (ComWrappersSupport.IsFreeThreaded(thisPtr))
{
var obj = new ObjectReference<T>(thisPtr, vftblT);
return obj;
}
else
{
var obj = new ObjectReferenceWithContext<T>(
thisPtr,
Context.GetContextCallback(),
Context.GetContextToken());
return obj;
}
}

public static ObjectReference<T> FromAbi(IntPtr thisPtr)
Expand Down Expand Up @@ -504,7 +545,7 @@ internal sealed class ObjectReferenceWithContext<
#endif
T> : ObjectReference<T>
{
private readonly IntPtr _contextCallbackPtr;
private readonly IntPtr _contextCallbackPtr;
private readonly IntPtr _contextToken;

private volatile ConcurrentDictionary<IntPtr, ObjectReference<T>> __cachedContext;
Expand All @@ -520,7 +561,7 @@ private ConcurrentDictionary<IntPtr, ObjectReference<T>> Make_CachedContext()
private volatile AgileReference __agileReference;
private AgileReference AgileReference => _isAgileReferenceSet ? __agileReference : Make_AgileReference();
private AgileReference Make_AgileReference()
{
{
Context.CallInContext(_contextCallbackPtr, _contextToken, InitAgileReference, null);

// Set after CallInContext callback given callback can fail to occur.
Expand All @@ -536,9 +577,9 @@ void InitAgileReference()
private readonly Guid _iid;

internal ObjectReferenceWithContext(IntPtr thisPtr, IntPtr contextCallbackPtr, IntPtr contextToken)
:base(thisPtr)
: base(thisPtr)
{
_contextCallbackPtr = contextCallbackPtr;
_contextCallbackPtr = contextCallbackPtr;
_contextToken = contextToken;
}

Expand All @@ -559,6 +600,11 @@ private protected override IntPtr GetThisPtrForCurrentContext()
return cachedObjRef.ThisPtr;
}

private protected override IntPtr GetContextToken()
{
return this._contextToken;
}

private protected override T GetVftblForCurrentContext()
{
ObjectReference<T> cachedObjRef = GetCurrentContext();
Expand Down
12 changes: 6 additions & 6 deletions src/cswinrt/code_writers.h
Original file line number Diff line number Diff line change
Expand Up @@ -1935,13 +1935,13 @@ private static % _% = new %("%.%", %.IID);
{
auto objrefname = w.write_temp("%", bind<write_objref_type_name>(classType));
w.write(R"(
private static volatile FactoryObjectReference<IActivationFactoryVftbl> __%;
private static FactoryObjectReference<IActivationFactoryVftbl> %
private static volatile ObjectReference<IActivationFactoryVftbl> __%;
private static ObjectReference<IActivationFactoryVftbl> %
{
get
{
var factory = __%;
if (factory != null && factory.IsObjectInContext())
if (factory != null && factory.IsInCurrentContext)
{
return factory;
}
Expand Down Expand Up @@ -1993,13 +1993,13 @@ private static ObjectReference<%> % => __% ?? Make__%();
{
auto objrefname = w.write_temp("%", bind<write_objref_type_name>(staticsType));
w.write(R"(
private static volatile FactoryObjectReference<%> __%;
private static FactoryObjectReference<%> %
private static volatile ObjectReference<%> __%;
private static ObjectReference<%> %
{
get
{
var factory = __%;
if (factory != null && factory.IsObjectInContext())
if (factory != null && factory.IsInCurrentContext)
{
return factory;
}
Expand Down
Loading