diff --git a/nuget/Microsoft.Windows.CsWinRT.targets b/nuget/Microsoft.Windows.CsWinRT.targets index 90bec30d2..7a5dcf1f8 100644 --- a/nuget/Microsoft.Windows.CsWinRT.targets +++ b/nuget/Microsoft.Windows.CsWinRT.targets @@ -79,7 +79,7 @@ Copyright (C) Microsoft Corporation. All rights reserved. diff --git a/src/Tests/TestComponentCSharp/Class.cpp b/src/Tests/TestComponentCSharp/Class.cpp index 05551fe96..b4d137797 100644 --- a/src/Tests/TestComponentCSharp/Class.cpp +++ b/src/Tests/TestComponentCSharp/Class.cpp @@ -1044,6 +1044,49 @@ namespace winrt::TestComponentCSharp::implementation return winrt::single_threaded_vector_view(std::vector{ *this, *this, *this }); } + IMap Class::GetIntToIntDictionary() + { + return single_threaded_map(std::map{ {1, 4}, { 2, 8 }, { 3, 12 } }); + } + + IMap Class::GetStringToBlittableDictionary() + { + return single_threaded_map(std::map + { + { L"alpha", ComposedBlittableStruct{ 5 } }, + { L"beta", ComposedBlittableStruct{ 4 } }, + { L"charlie", ComposedBlittableStruct{ 7 } } + }); + } + + IMap Class::GetStringToNonBlittableDictionary() + { + return single_threaded_map(std::map + { + { L"String0", ComposedNonBlittableStruct{ { 0 }, { L"String0" }, { true, false, true, false }, { 0 } } }, + { L"String1", ComposedNonBlittableStruct{ { 1 }, { L"String1" }, { false, true, false, true }, { 1 } } }, + { L"String2", ComposedNonBlittableStruct{ { 2 }, { L"String2" }, { true, false, true, false }, { 2 } } } + }); + } + + struct ComposedBlittableStructComparer + { + bool operator() (const TestComponentCSharp::ComposedBlittableStruct& l, const TestComponentCSharp::ComposedBlittableStruct& r) const + { + return (l.blittable.i32 < r.blittable.i32); + } + }; + + IMap Class::GetBlittableToObjectDictionary() + { + return single_threaded_map(std::map + { + { ComposedBlittableStruct{ 1 }, winrt::box_value(0) }, + { ComposedBlittableStruct{ 4 }, winrt::box_value(L"box") }, + { ComposedBlittableStruct{ 8 }, *this } + }); + } + // Test IIDOptimizer IVectorView Class::GetEventArgsVector() { diff --git a/src/Tests/TestComponentCSharp/Class.h b/src/Tests/TestComponentCSharp/Class.h index d8c96f09a..87d440584 100644 --- a/src/Tests/TestComponentCSharp/Class.h +++ b/src/Tests/TestComponentCSharp/Class.h @@ -273,6 +273,11 @@ namespace winrt::TestComponentCSharp::implementation Windows::Foundation::Collections::IVectorView GetObjectVector(); Windows::Foundation::Collections::IVectorView GetInterfaceVector(); Windows::Foundation::Collections::IVectorView GetClassVector() noexcept; + + Windows::Foundation::Collections::IMap GetIntToIntDictionary(); + Windows::Foundation::Collections::IMap GetStringToBlittableDictionary(); + Windows::Foundation::Collections::IMap GetStringToNonBlittableDictionary(); + Windows::Foundation::Collections::IMap GetBlittableToObjectDictionary(); // Test IIDOptimizer -- testing the windows projection covers most code paths, and these two types exercise the rest. Windows::Foundation::Collections::IVectorView GetEventArgsVector(); diff --git a/src/Tests/TestComponentCSharp/TestComponentCSharp.idl b/src/Tests/TestComponentCSharp/TestComponentCSharp.idl index 50ebf076e..a30a24798 100644 --- a/src/Tests/TestComponentCSharp/TestComponentCSharp.idl +++ b/src/Tests/TestComponentCSharp/TestComponentCSharp.idl @@ -315,6 +315,11 @@ namespace TestComponentCSharp Windows.Foundation.Collections.IVectorView GetInterfaceVector(); [noexcept] Windows.Foundation.Collections.IVectorView GetClassVector(); + Windows.Foundation.Collections.IMap GetIntToIntDictionary(); + Windows.Foundation.Collections.IMap GetStringToBlittableDictionary(); + Windows.Foundation.Collections.IMap GetStringToNonBlittableDictionary(); + Windows.Foundation.Collections.IMap GetBlittableToObjectDictionary(); + // Test IIDOptimizer Windows.Foundation.Collections.IVectorView GetEventArgsVector(); Windows.Foundation.Collections.IVectorView GetNonGenericDelegateVector(); diff --git a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs index 348d4c1fe..22d7a5e3d 100644 --- a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs +++ b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs @@ -3069,5 +3069,31 @@ public void TestWeakReferenceEventsFromMultipleContexts() staThread.Start(); staThread.Join(); } + + [Fact] + public void TestDictionary() + { + var intToIntDict = TestObject.GetIntToIntDictionary(); + Assert.Equal(8, intToIntDict[2]); + Assert.Equal(8, intToIntDict[2]); + Assert.Equal(12, intToIntDict[3]); + + var stringToBlittableDict = TestObject.GetStringToBlittableDictionary(); + Assert.Equal(5, stringToBlittableDict["alpha"].blittable.i32); + Assert.Equal(7, stringToBlittableDict["charlie"].blittable.i32); + Assert.Equal(5, stringToBlittableDict["alpha"].blittable.i32); + + var stringToNonBlittableDict = TestObject.GetStringToNonBlittableDictionary(); + Assert.Equal(1, stringToNonBlittableDict["String1"].blittable.i32); + Assert.Equal("String1", stringToNonBlittableDict["String1"].strings.str); + Assert.False(stringToNonBlittableDict["String1"].bools.w); + Assert.True(stringToNonBlittableDict["String1"].bools.x); + + var blittableToObjectDict = TestObject.GetBlittableToObjectDictionary(); + ComposedBlittableStruct key; + key.blittable.i32 = 4; + Assert.Equal("box", (string)blittableToObjectDict[key]); + Assert.Equal("box", (string)blittableToObjectDict[key]); + } } } diff --git a/src/WinRT.Runtime/Projections/IDictionary.net5.cs b/src/WinRT.Runtime/Projections/IDictionary.net5.cs index 74509f9ac..47cd34ac4 100644 --- a/src/WinRT.Runtime/Projections/IDictionary.net5.cs +++ b/src/WinRT.Runtime/Projections/IDictionary.net5.cs @@ -38,12 +38,10 @@ namespace System.Collections.Generic internal sealed class IDictionaryImpl : IDictionary, IWinRTObject { private IObjectReference _inner; - private Dictionary _lookupCache; internal IDictionaryImpl(IObjectReference _inner) { this._inner = _inner; - this._lookupCache = new Dictionary(); } public static IDictionaryImpl CreateRcw(IInspectable obj) => new(obj.ObjRef); @@ -85,7 +83,7 @@ private IObjectReference Make_IEnumerableObjRef() public V this[K key] { - get => ABI.System.Collections.Generic.IDictionaryMethods.Indexer_Get(iDictionaryObjRef, _lookupCache, key); + get => ABI.System.Collections.Generic.IDictionaryMethods.Indexer_Get(iDictionaryObjRef, null, key); set => ABI.System.Collections.Generic.IDictionaryMethods.Indexer_Set(iDictionaryObjRef, key, value); } @@ -114,7 +112,7 @@ public void Clear() public bool Contains(KeyValuePair item) { - return ABI.System.Collections.Generic.IDictionaryMethods.Contains(iDictionaryObjRef, _lookupCache, item); + return ABI.System.Collections.Generic.IDictionaryMethods.Contains(iDictionaryObjRef, null, item); } public bool ContainsKey(K key) @@ -144,7 +142,7 @@ public bool Remove(KeyValuePair item) public bool TryGetValue(K key, [MaybeNullWhen(false)] out V value) { - return ABI.System.Collections.Generic.IDictionaryMethods.TryGetValue(iDictionaryObjRef, _lookupCache, key, out value); + return ABI.System.Collections.Generic.IDictionaryMethods.TryGetValue(iDictionaryObjRef, null, key, out value); } IEnumerator IEnumerable.GetEnumerator() @@ -162,7 +160,7 @@ namespace ABI.Windows.Foundation.Collections internal static class IMapMethods { - public static unsafe V Lookup(IObjectReference obj, Dictionary __lookupCache, K key) + public static unsafe V Lookup(IObjectReference obj, K key) { var _obj = (ObjectReference.Vftbl>)obj; var ThisPtr = _obj.ThisPtr; @@ -174,19 +172,7 @@ public static unsafe V Lookup(IObjectReference obj, Dictionary _ __params[1] = Marshaler.GetAbi(__key); _obj.Vftbl.Lookup_0.DynamicInvokeAbi(__params); - if (__lookupCache != null && __lookupCache.TryGetValue(key, out var __cachedRcw) && __cachedRcw.Item1 == (IntPtr)__params[2]) - { - return __cachedRcw.Item2; - } - else - { - var value = Marshaler.FromAbi(__params[2]); - if (__lookupCache != null) - { - __lookupCache[key] = ((IntPtr)__params[2], value); - } - return value; - } + return Marshaler.FromAbi(__params[2]); } finally { @@ -334,7 +320,7 @@ public static bool Contains(IObjectReference obj, Dictionary __l if (!hasKey) return false; // todo: toctou - V value = IMapMethods.Lookup(obj, __lookupCache, item.Key); + V value = IMapMethods.Lookup(obj, item.Key); return EqualityComparer.Default.Equals(value, item.Value); } @@ -368,7 +354,7 @@ public static V Indexer_Get(IObjectReference obj, Dictionary __l { if (key == null) throw new ArgumentNullException(nameof(key)); - return Lookup(obj, __lookupCache, key); + return Lookup(obj, key); } public static void Indexer_Set(IObjectReference obj, K key, V value) @@ -435,7 +421,7 @@ public static bool TryGetValue(IObjectReference obj, Dictionary try { - value = Lookup(obj, __lookupCache, key); + value = Lookup(obj, key); return true; } catch (KeyNotFoundException) @@ -445,13 +431,13 @@ public static bool TryGetValue(IObjectReference obj, Dictionary } } - private static V Lookup(IObjectReference obj, Dictionary __lookupCache, K key) + private static V Lookup(IObjectReference obj, K key) { Debug.Assert(null != key); try { - return IMapMethods.Lookup(obj, __lookupCache, key); + return IMapMethods.Lookup(obj, key); } catch (global::System.Exception ex) { @@ -1013,7 +999,7 @@ public static ObjectReference ObjRefFromAbi(IntPtr thisPtr) get { var _obj = ((ObjectReference)((IWinRTObject)this).GetObjectReferenceForType(typeof(global::System.Collections.Generic.IDictionary).TypeHandle)); - return IDictionaryMethods.Indexer_Get(_obj, GetLookupCache((IWinRTObject)this), key); + return IDictionaryMethods.Indexer_Get(_obj, null, key); } set { @@ -1040,16 +1026,10 @@ public static ObjectReference ObjRefFromAbi(IntPtr thisPtr) return IDictionaryMethods.Remove(_obj, key); } - internal static global::System.Collections.Generic.Dictionary GetLookupCache(IWinRTObject _this) - { - return (Dictionary)_this.GetOrCreateTypeHelperData(typeof(global::System.Collections.Generic.IDictionary).TypeHandle, - () => new Dictionary()); - } - bool global::System.Collections.Generic.IDictionary.TryGetValue(K key, out V value) { var _obj = ((ObjectReference)((IWinRTObject)this).GetObjectReferenceForType(typeof(global::System.Collections.Generic.IDictionary).TypeHandle)); - return IDictionaryMethods.TryGetValue(_obj, GetLookupCache((IWinRTObject)this), key, out value); + return IDictionaryMethods.TryGetValue(_obj, null, key, out value); } void global::System.Collections.Generic.ICollection>.Add(global::System.Collections.Generic.KeyValuePair item) @@ -1061,7 +1041,7 @@ public static ObjectReference ObjRefFromAbi(IntPtr thisPtr) bool global::System.Collections.Generic.ICollection>.Contains(global::System.Collections.Generic.KeyValuePair item) { var _obj = ((ObjectReference)((IWinRTObject)this).GetObjectReferenceForType(typeof(global::System.Collections.Generic.IDictionary).TypeHandle)); - return IDictionaryMethods.Contains(_obj, GetLookupCache((IWinRTObject)this), item); + return IDictionaryMethods.Contains(_obj, null, item); } void global::System.Collections.Generic.ICollection>.CopyTo(global::System.Collections.Generic.KeyValuePair[] array, int arrayIndex) diff --git a/src/cswinrt/code_writers.h b/src/cswinrt/code_writers.h index 8463bd97e..8f8c0180a 100644 --- a/src/cswinrt/code_writers.h +++ b/src/cswinrt/code_writers.h @@ -1862,7 +1862,7 @@ internal static % Instance => _instance; else { w.write(R"( -internal sealed class _% : IWinRTObject +internal sealed class _% { private IObjectReference _obj; public _%() @@ -1872,23 +1872,6 @@ _obj = %(GuidGenerator.GetIID(typeof(%.%).GetHelperType())); % internal static % Instance => (%)_instance; - -IObjectReference IWinRTObject.NativeObject => _obj; -bool IWinRTObject.HasUnwrappableNativeObject => false; -private volatile global::System.Collections.Concurrent.ConcurrentDictionary _queryInterfaceCache; -private global::System.Collections.Concurrent.ConcurrentDictionary MakeQueryInterfaceCache() -{ - global::System.Threading.Interlocked.CompareExchange(ref _queryInterfaceCache, new global::System.Collections.Concurrent.ConcurrentDictionary(), null); - return _queryInterfaceCache; -} -global::System.Collections.Concurrent.ConcurrentDictionary IWinRTObject.QueryInterfaceCache => _queryInterfaceCache ?? MakeQueryInterfaceCache(); -private volatile global::System.Collections.Concurrent.ConcurrentDictionary _additionalTypeData; -private global::System.Collections.Concurrent.ConcurrentDictionary MakeAdditionalTypeData() -{ - global::System.Threading.Interlocked.CompareExchange(ref _additionalTypeData, new global::System.Collections.Concurrent.ConcurrentDictionary(), null); - return _additionalTypeData; -} -global::System.Collections.Concurrent.ConcurrentDictionary IWinRTObject.AdditionalTypeData => _additionalTypeData ?? MakeAdditionalTypeData(); } )", cache_type_name, @@ -2494,28 +2477,25 @@ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); auto enumerableObjRefName = std::regex_replace(objref_name, std::regex("IDictionary"), "IEnumerable_global__System_Collections_Generic_KeyValuePair") + "_"; w.write(R"( -private Dictionary<%, (IntPtr, %)> _lookupCache = new Dictionary<%, (IntPtr, %)>(); - %ICollection<%> %Keys => %.get_Keys(%); %ICollection<%> %Values => %.get_Values(%); %int %Count => %.get_Count(%); %bool %IsReadOnly => %.get_IsReadOnly(%); %% %this[% key] { -get => %.Indexer_Get(%, _lookupCache, key); +get => %.Indexer_Get(%, null, key); set => %.Indexer_Set(%, key, value); } %void %Add(% key, % value) => %.Add(%, key, value); %bool %ContainsKey(% key) => %.ContainsKey(%, key); %bool %Remove(% key) => %.Remove(%, key); -%bool %TryGetValue(% key, out % value) => %.TryGetValue(%, _lookupCache, key, out value); +%bool %TryGetValue(% key, out % value) => %.TryGetValue(%, null, key, out value); %void %Add(KeyValuePair<%, %> item) => %.Add(%, item); %void %Clear() => %.Clear(%); -%bool %Contains(KeyValuePair<%, %> item) => %.Contains(%, _lookupCache, item); +%bool %Contains(KeyValuePair<%, %> item) => %.Contains(%, null, item); %void %CopyTo(KeyValuePair<%, %>[] array, int arrayIndex) => %.CopyTo(%, %, array, arrayIndex); bool ICollection>.Remove(KeyValuePair<%, %> item) => %.Remove(%, item); )", -key, value, key, value, visibility, key, self, abiClass, objref_name, //Keys visibility, value, self, abiClass, objref_name, // Values visibility, icollection, abiClass, objref_name, // Count @@ -4095,7 +4075,7 @@ internal static _% Instance => _instance; else { w.write(R"( -internal sealed class _% : IWinRTObject +internal sealed class _% { private IObjectReference _obj; private IntPtr ThisPtr => _obj.ThisPtr; @@ -4106,24 +4086,6 @@ _obj = ActivationFactory<%>.As(GuidGenerator.GetIID(typeof(%.%).GetHelperType()) private static _% _instance = new _%(); internal static _% Instance => _instance; - -IObjectReference IWinRTObject.NativeObject => _obj; -bool IWinRTObject.HasUnwrappableNativeObject => false; -private volatile global::System.Collections.Concurrent.ConcurrentDictionary _queryInterfaceCache; -private global::System.Collections.Concurrent.ConcurrentDictionary MakeQueryInterfaceCache() -{ - global::System.Threading.Interlocked.CompareExchange(ref _queryInterfaceCache, new global::System.Collections.Concurrent.ConcurrentDictionary(), null); - return _queryInterfaceCache; -} -global::System.Collections.Concurrent.ConcurrentDictionary IWinRTObject.QueryInterfaceCache => _queryInterfaceCache ?? MakeQueryInterfaceCache(); -private volatile global::System.Collections.Concurrent.ConcurrentDictionary _additionalTypeData; -private global::System.Collections.Concurrent.ConcurrentDictionary MakeAdditionalTypeData() -{ - global::System.Threading.Interlocked.CompareExchange(ref _additionalTypeData, new global::System.Collections.Concurrent.ConcurrentDictionary(), null); - return _additionalTypeData; -} -global::System.Collections.Concurrent.ConcurrentDictionary IWinRTObject.AdditionalTypeData => _additionalTypeData ?? MakeAdditionalTypeData(); - % } )", diff --git a/src/cswinrt/strings/WinRT.cs b/src/cswinrt/strings/WinRT.cs index 0ceb2ed7b..16bbfc486 100644 --- a/src/cswinrt/strings/WinRT.cs +++ b/src/cswinrt/strings/WinRT.cs @@ -47,14 +47,55 @@ internal static unsafe int CoCreateInstance(ref Guid clsid, IntPtr outer, uint c internal static extern int CoDecrementMTAUsage(IntPtr cookie); [DllImport("api-ms-win-core-com-l1-1-0.dll")] - internal static extern unsafe int CoIncrementMTAUsage(IntPtr* cookie); + internal static extern unsafe int CoIncrementMTAUsage(IntPtr* cookie); + +#if NET6_0_OR_GREATER + internal static bool FreeLibrary(IntPtr moduleHandle) + { + int lastError; + bool returnValue; + int nativeReturnValue; + { + Marshal.SetLastSystemError(0); + nativeReturnValue = PInvoke(moduleHandle); + lastError = Marshal.GetLastSystemError(); + } + // Unmarshal - Convert native data to managed data. + returnValue = nativeReturnValue != 0; + Marshal.SetLastPInvokeError(lastError); + return returnValue; + + // Local P/Invoke + [DllImportAttribute("kernel32.dll", EntryPoint = "FreeLibrary", ExactSpelling = true)] + static extern unsafe int PInvoke(IntPtr nativeModuleHandle); + } + + internal static unsafe void* TryGetProcAddress(IntPtr moduleHandle, sbyte* functionName) + { + int lastError; + void* returnValue; + { + Marshal.SetLastSystemError(0); + returnValue = PInvoke(moduleHandle, functionName); + lastError = Marshal.GetLastSystemError(); + } + + Marshal.SetLastPInvokeError(lastError); + return returnValue; + + // Local P/Invoke + [DllImportAttribute("kernel32.dll", EntryPoint = "GetProcAddress", ExactSpelling = true)] + static extern unsafe void* PInvoke(IntPtr nativeModuleHandle, sbyte* nativeFunctionName); + } +#else [DllImport("kernel32.dll", SetLastError = true)] [return: MarshalAs(UnmanagedType.Bool)] internal static extern bool FreeLibrary(IntPtr moduleHandle); [DllImport("kernel32.dll", EntryPoint = "GetProcAddress", SetLastError = true, BestFitMapping = false)] internal static unsafe extern void* TryGetProcAddress(IntPtr moduleHandle, sbyte* functionName); +#endif internal static unsafe void* TryGetProcAddress(IntPtr moduleHandle, ReadOnlySpan functionName) { @@ -130,17 +171,36 @@ internal static unsafe int CoCreateInstance(ref Guid clsid, IntPtr outer, uint c Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error(), new IntPtr(-1)); } return functionPtr; - } - + } + +#if NET6_0_OR_GREATER + internal static unsafe IntPtr LoadLibraryExW(ushort* fileName, IntPtr fileHandle, uint flags) + { + int lastError; + IntPtr returnValue; + { + Marshal.SetLastSystemError(0); + returnValue = PInvoke(fileName, fileHandle, flags); + lastError = Marshal.GetLastSystemError(); + } + + Marshal.SetLastPInvokeError(lastError); + return returnValue; + + // Local P/Invoke + [DllImportAttribute("kernel32.dll", EntryPoint = "LoadLibraryExW", ExactSpelling = true)] + static extern unsafe IntPtr PInvoke(ushort* nativeFileName, IntPtr nativeFileHandle, uint nativeFlags); + } +#else [DllImport("kernel32.dll", SetLastError = true)] internal static unsafe extern IntPtr LoadLibraryExW(ushort* fileName, IntPtr fileHandle, uint flags); - +#endif internal static unsafe IntPtr LoadLibraryExW(string fileName, IntPtr fileHandle, uint flags) { fixed (char* lpFileName = fileName) return LoadLibraryExW((ushort*)lpFileName, fileHandle, flags); } - + [DllImport("api-ms-win-core-winrt-l1-1-0.dll")] internal static extern unsafe int RoGetActivationFactory(IntPtr runtimeClassId, Guid* iid, IntPtr* factory);