Skip to content

Commit

Permalink
Special casing System.Guid for COM VARIANT marshalling (#100377)
Browse files Browse the repository at this point in the history
* Support System.Guid marshalling via VARIANT

VARIANT marshalling in .NET 5+ requires a TLB
for COM records (i.e., ValueType instances). This
means that without a runtime provided TLB, users
must define their own TLB for runtime types or
define their own transfer types.

We address this here by deferring to the NetFX
mscorlib's TLB.

Co-authored-by: Elinor Fung <elfung@microsoft.com>
  • Loading branch information
AaronRobinsonMSFT and elinor-fung authored Apr 5, 2024
1 parent 4ce3525 commit f6237bc
Show file tree
Hide file tree
Showing 27 changed files with 726 additions and 19 deletions.
33 changes: 25 additions & 8 deletions src/coreclr/vm/olevariant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2567,17 +2567,34 @@ void OleVariant::MarshalRecordVariantOleToCom(VARIANT *pOleVariant,
if (!pRecInfo)
COMPlusThrow(kArgumentException, IDS_EE_INVALID_OLE_VARIANT);

LPVOID pvRecord = V_RECORD(pOleVariant);
if (pvRecord == NULL)
{
pComVariant->SetObjRef(NULL);
return;
}

MethodTable* pValueClass = NULL;
{
GCX_PREEMP();
pValueClass = GetMethodTableForRecordInfo(pRecInfo);
}

if (pValueClass == NULL)
{
// This value type should have been registered through
// a TLB. CoreCLR doesn't support dynamic type mapping.
COMPlusThrow(kArgumentException, IDS_EE_CANNOT_MAP_TO_MANAGED_VC);
}
_ASSERTE(pValueClass->IsBlittable());

OBJECTREF BoxedValueClass = NULL;
GCPROTECT_BEGIN(BoxedValueClass)
{
LPVOID pvRecord = V_RECORD(pOleVariant);
if (pvRecord)
{
// This value type should have been registered through
// a TLB. CoreCLR doesn't support dynamic type mapping.
COMPlusThrow(kArgumentException, IDS_EE_CANNOT_MAP_TO_MANAGED_VC);
}

// Now that we have a blittable value class, allocate an instance of the
// boxed value class and copy the contents of the record into it.
BoxedValueClass = AllocateObject(pValueClass);
memcpyNoGCRefs(BoxedValueClass->GetData(), (BYTE*)pvRecord, pValueClass->GetNativeSize());
pComVariant->SetObjRef(BoxedValueClass);
}
GCPROTECT_END();
Expand Down
94 changes: 94 additions & 0 deletions src/coreclr/vm/stdinterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,43 @@ HRESULT GetITypeLibForAssembly(_In_ Assembly *pAssembly, _Outptr_ ITypeLib **ppT
return S_OK;
} // HRESULT GetITypeLibForAssembly()

// .NET Framework's mscorlib TLB GUID.
static const GUID s_MscorlibGuid = { 0xBED7F4EA, 0x1A96, 0x11D2, { 0x8F, 0x08, 0x00, 0xA0, 0xC9, 0xA6, 0x18, 0x6D } };

// Hard-coded GUID for System.Guid.
static const GUID s_GuidForSystemGuid = { 0x9C5923E9, 0xDE52, 0x33EA, { 0x88, 0xDE, 0x7E, 0xBC, 0x86, 0x33, 0xB9, 0xCC } };

// There are types that are helpful to provide that facilitate porting from
// .NET Framework to .NET 8+. This function is used to acquire their ITypeInfo.
// This should be used narrowly. Types at a minimum should be blittable.
static bool TryDeferToMscorlib(MethodTable* pClass, ITypeInfo** ppTI)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_PREEMPTIVE;
PRECONDITION(pClass != NULL);
PRECONDITION(pClass->IsBlittable());
PRECONDITION(ppTI != NULL);
}
CONTRACTL_END;

// Marshalling of System.Guid is a common scenario that impacts many teams porting
// code to .NET 8+. Try to load the .NET Framework's TLB to support this scenario.
if (pClass == CoreLibBinder::GetClass(CLASS__GUID))
{
SafeComHolder<ITypeLib> pMscorlibTypeLib = NULL;
if (SUCCEEDED(::LoadRegTypeLib(s_MscorlibGuid, 2, 4, 0, &pMscorlibTypeLib)))
{
if (SUCCEEDED(pMscorlibTypeLib->GetTypeInfoOfGuid(s_GuidForSystemGuid, ppTI)))
return true;
}
}

return false;
}

HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClassInfo)
{
CONTRACTL
Expand All @@ -625,6 +662,7 @@ HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClas
GUID clsid;
GUID ciid;
ComMethodTable *pComMT = NULL;
MethodTable* pOriginalClass = pClass;
HRESULT hr = S_OK;
SafeComHolder<ITypeLib> pITLB = NULL;
SafeComHolder<ITypeInfo> pTI = NULL;
Expand Down Expand Up @@ -770,12 +808,68 @@ HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClas
{
if (!FAILED(hr))
hr = E_FAIL;

if (pOriginalClass->IsValueType() && pOriginalClass->IsBlittable())
{
if (TryDeferToMscorlib(pOriginalClass, ppTI))
hr = S_OK;
}
}

ReturnHR:
return hr;
} // HRESULT GetITypeInfoForEEClass()

// Only a narrow set of types are supported.
// See TryDeferToMscorlib() above.
MethodTable* GetMethodTableForRecordInfo(IRecordInfo* recInfo)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_PREEMPTIVE;
PRECONDITION(recInfo != NULL);
}
CONTRACTL_END;

HRESULT hr;

// Verify the associated TypeLib attribute
SafeComHolder<ITypeInfo> typeInfo;
hr = recInfo->GetTypeInfo(&typeInfo);
if (FAILED(hr))
return NULL;

SafeComHolder<ITypeLib> typeLib;
UINT index;
hr = typeInfo->GetContainingTypeLib(&typeLib, &index);
if (FAILED(hr))
return NULL;

TLIBATTR* attrs;
hr = typeLib->GetLibAttr(&attrs);
if (FAILED(hr))
return NULL;

GUID libGuid = attrs->guid;
typeLib->ReleaseTLibAttr(attrs);
if (s_MscorlibGuid != libGuid)
return NULL;

// Verify the Guid of the associated type
GUID typeGuid;
hr = recInfo->GetGuid(&typeGuid);
if (FAILED(hr))
return NULL;

// Check for supported types.
if (s_GuidForSystemGuid == typeGuid)
return CoreLibBinder::GetClass(CLASS__GUID);

return NULL;
}

// Returns a NON-ADDREF'd ITypeInfo.
HRESULT GetITypeInfoForMT(ComMethodTable *pMT, ITypeInfo **ppTI)
{
Expand Down
3 changes: 3 additions & 0 deletions src/coreclr/vm/stdinterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,7 @@ IErrorInfo *GetSupportedErrorInfo(IUnknown *iface, REFIID riid);
// Helpers to get the ITypeInfo* for a type.
HRESULT GetITypeInfoForEEClass(MethodTable *pMT, ITypeInfo **ppTI, bool bClassInfo = false);

// Gets the MethodTable for the associated IRecordInfo.
MethodTable* GetMethodTableForRecordInfo(IRecordInfo* recInfo);

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,6 @@ public static unsafe void SetAsByrefVariantIndirect(ref this ComVariant variant,
variant.SetAsByrefVariant(ref value);
return;
case VarEnum.VT_RECORD:
// VT_RECORD's are weird in that regardless of is the VT_BYREF flag is set or not
// they have the same internal representation.
variant = ComVariant.CreateRaw(value.VarType | VarEnum.VT_BYREF, value.GetRawDataRef<Record>());
break;
case VarEnum.VT_DECIMAL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,11 @@ public static unsafe ComVariant CreateRaw<T>(VarEnum vt, T rawValue)
(VarEnum.VT_UNKNOWN or VarEnum.VT_DISPATCH or VarEnum.VT_LPSTR or VarEnum.VT_BSTR or VarEnum.VT_LPWSTR or VarEnum.VT_SAFEARRAY
or VarEnum.VT_CLSID or VarEnum.VT_STREAM or VarEnum.VT_STREAMED_OBJECT or VarEnum.VT_STORAGE or VarEnum.VT_STORED_OBJECT or VarEnum.VT_CF or VT_VERSIONED_STREAM, _) when sizeof(T) == nint.Size => rawValue,
(VarEnum.VT_CY or VarEnum.VT_FILETIME, 8) => rawValue,
(VarEnum.VT_RECORD, _) when sizeof(T) == sizeof(Record) => rawValue,

// VT_RECORDs are weird in that regardless of whether the VT_BYREF flag is set or not
// they have the same internal representation.
(VarEnum.VT_RECORD or VarEnum.VT_RECORD | VarEnum.VT_BYREF, _) when sizeof(T) == sizeof(Record) => rawValue,

_ when vt.HasFlag(VarEnum.VT_BYREF) && sizeof(T) == nint.Size => rawValue,
_ when vt.HasFlag(VarEnum.VT_VECTOR) && sizeof(T) == sizeof(Vector<byte>) => rawValue,
_ when vt.HasFlag(VarEnum.VT_ARRAY) && sizeof(T) == nint.Size => rawValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,40 @@ public void GetNativeVariantForObject_String_Success(string obj)
}
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
public unsafe void GetNativeVariantForObject_Guid_Success()
{
var guid = new Guid("0DD3E51B-3162-4D13-B906-030F402C5BA2");
var v = new Variant();
IntPtr pNative = Marshal.AllocHGlobal(Marshal.SizeOf(v));
try
{
if (PlatformDetection.IsWindowsNanoServer)
{
Assert.Throws<NotSupportedException>(() => Marshal.GetNativeVariantForObject(guid, pNative));
}
else
{
Marshal.GetNativeVariantForObject(guid, pNative);

Variant result = Marshal.PtrToStructure<Variant>(pNative);
Assert.Equal(VarEnum.VT_RECORD, (VarEnum)result.vt);
Assert.NotEqual(nint.Zero, result.pRecInfo); // We should have an IRecordInfo instance.

var expectedBytes = new ReadOnlySpan<byte>(guid.ToByteArray());
var actualBytes = new ReadOnlySpan<byte>((void*)result.bstrVal, expectedBytes.Length);
Assert.Equal(expectedBytes, actualBytes);

object o = Marshal.GetObjectForNativeVariant(pNative);
Assert.Equal(guid, o);
}
}
finally
{
Marshal.FreeHGlobal(pNative);
}
}

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
[InlineData(3.14)]
public unsafe void GetNativeVariantForObject_Double_Success(double obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,38 @@ public void GetObjectForNativeVariant_InvalidDate_ThrowsArgumentException(double
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
public void GetObjectForNativeVariant_NoDataForRecord_ThrowsArgumentException()
public void GetObjectForNativeVariant_NoRecordInfo_ThrowsArgumentException()
{
Variant variant = CreateVariant(VT_RECORD, new UnionTypes { _record = new Record { _recordInfo = IntPtr.Zero } });
AssertExtensions.Throws<ArgumentException>(null, () => GetObjectForNativeVariant(variant));
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
public void GetObjectForNativeVariant_NoRecordData_ReturnsNull()
{
var recordInfo = new RecordInfo();
IntPtr pRecordInfo = Marshal.GetComInterfaceForObject<RecordInfo, IRecordInfo>(recordInfo);
try
{
Variant variant = CreateVariant(VT_RECORD, new UnionTypes
{
_record = new Record
{
_record = IntPtr.Zero,
_recordInfo = pRecordInfo
}
});
Assert.Null(GetObjectForNativeVariant(variant));
}
finally
{
Marshal.Release(pRecordInfo);
}
}

public static IEnumerable<object[]> GetObjectForNativeVariant_NoSuchGuid_TestData()
{
yield return new object[] { typeof(object).GUID };
yield return new object[] { typeof(string).GUID };
yield return new object[] { Guid.Empty };
}
Expand Down
1 change: 1 addition & 0 deletions src/tests/Interop/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ if(CLR_CMAKE_TARGET_WIN32)
add_subdirectory(COM/NativeClients/DefaultInterfaces)
add_subdirectory(COM/NativeClients/Dispatch)
add_subdirectory(COM/NativeClients/Events)
add_subdirectory(COM/NativeClients/MiscTypes)
add_subdirectory(COM/ComWrappers/MockReferenceTrackerRuntime)
add_subdirectory(COM/ComWrappers/WeakReference)

Expand Down
11 changes: 11 additions & 0 deletions src/tests/Interop/COM/Dynamic/BasicTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public void Run()

String();
Date();
SpecialCasedValueTypes();
ComObject();
Null();

Expand Down Expand Up @@ -385,6 +386,16 @@ private void Date()
Variant<DateTime>(val, expected);
}

private void SpecialCasedValueTypes()
{
{
var val = Guid.NewGuid();
var expected = val;
// Pass as variant
Variant<Guid>(val, expected);
}
}

private void ComObject()
{
Type t = Type.GetTypeFromCLSID(Guid.Parse(ServerGuids.BasicTest));
Expand Down
18 changes: 18 additions & 0 deletions src/tests/Interop/COM/NETClients/MiscTypes/App.manifest
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?xml version="1.0" encoding="utf-8"?>
<assembly manifestVersion="1.0" xmlns="urn:schemas-microsoft-com:asm.v1">
<assemblyIdentity
type="win32"
name="NetClientMiscTypes"
version="1.0.0.0" />

<dependency>
<dependentAssembly>
<!-- RegFree COM -->
<assemblyIdentity
type="win32"
name="COMNativeServer.X"
version="1.0.0.0"/>
</dependentAssembly>
</dependency>

</assembly>
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<!-- Needed for CMakeProjectReference, GC.WaitForPendingFinalizers -->
<RequiresProcessIsolation>true</RequiresProcessIsolation>
<ApplicationManifest>App.manifest</ApplicationManifest>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<Compile Include="Program.cs" />
<Compile Include="../../ServerContracts/Server.CoClasses.cs" />
<Compile Include="../../ServerContracts/Server.Contracts.cs" />
<Compile Include="../../ServerContracts/ServerGuids.cs" />
</ItemGroup>
<ItemGroup>
<CMakeProjectReference Include="../../NativeServer/CMakeLists.txt" />
<ProjectReference Include="$(TestLibraryProjectPath)" />
</ItemGroup>
</Project>
Loading

0 comments on commit f6237bc

Please sign in to comment.