diff --git a/src/WinRT.Runtime/ExceptionHelpers.cs b/src/WinRT.Runtime/ExceptionHelpers.cs index d25dc67e9..f486e514b 100644 --- a/src/WinRT.Runtime/ExceptionHelpers.cs +++ b/src/WinRT.Runtime/ExceptionHelpers.cs @@ -6,7 +6,7 @@ namespace WinRT { - public static class ExceptionHelpers + public static unsafe class ExceptionHelpers { private const int COR_E_OBJECTDISPOSED = unchecked((int)0x80131622); private const int RO_E_CLOSED = unchecked((int)0x80000013); @@ -25,58 +25,54 @@ public static class ExceptionHelpers [DllImport("oleaut32.dll")] private static extern int SetErrorInfo(uint dwReserved, IntPtr perrinfo); - internal delegate int GetRestrictedErrorInfo(out IntPtr ppRestrictedErrorInfo); - private static GetRestrictedErrorInfo getRestrictedErrorInfo; + private static delegate* unmanaged[Stdcall] getRestrictedErrorInfo; + private static delegate* unmanaged[Stdcall] setRestrictedErrorInfo; + private static delegate* unmanaged[Stdcall] roOriginateLanguageException; + private static delegate* unmanaged[Stdcall] roReportUnhandledError; - internal delegate int SetRestrictedErrorInfo(IntPtr pRestrictedErrorInfo); - private static SetRestrictedErrorInfo setRestrictedErrorInfo; + private static readonly bool initialized = Initialize(); - internal delegate int RoOriginateLanguageException(int error, IntPtr message, IntPtr langaugeException); - private static RoOriginateLanguageException roOriginateLanguageException; - - internal delegate int RoReportUnhandledError(IntPtr pRestrictedErrorInfo); - private static RoReportUnhandledError roReportUnhandledError; - - static ExceptionHelpers() + private static bool Initialize() { IntPtr winRTErrorModule = Platform.LoadLibraryExW("api-ms-win-core-winrt-error-l1-1-1.dll", IntPtr.Zero, (uint)DllImportSearchPath.System32); if (winRTErrorModule != IntPtr.Zero) { - getRestrictedErrorInfo = Platform.GetProcAddress(winRTErrorModule); - setRestrictedErrorInfo = Platform.GetProcAddress(winRTErrorModule); - roOriginateLanguageException = Platform.GetProcAddress(winRTErrorModule); - roReportUnhandledError = Platform.GetProcAddress(winRTErrorModule); + roOriginateLanguageException = (delegate* unmanaged[Stdcall])Platform.GetProcAddress(winRTErrorModule, "RoOriginateLanguageException"); + roReportUnhandledError = (delegate* unmanaged[Stdcall])Platform.GetProcAddress(winRTErrorModule, "RoReportUnhandledError"); } else { winRTErrorModule = Platform.LoadLibraryExW("api-ms-win-core-winrt-error-l1-1-0.dll", IntPtr.Zero, (uint)DllImportSearchPath.System32); - if (winRTErrorModule != IntPtr.Zero) + } + + if (winRTErrorModule != IntPtr.Zero) + { + getRestrictedErrorInfo = (delegate* unmanaged[Stdcall])Platform.GetProcAddress(winRTErrorModule, "GetRestrictedErrorInfo"); + setRestrictedErrorInfo = (delegate* unmanaged[Stdcall])Platform.GetProcAddress(winRTErrorModule, "SetRestrictedErrorInfo"); + } + + return true; + } + + public static void ThrowExceptionForHR(int hr) + { + if (hr < 0) + { + Throw(hr); + } + + static void Throw(int hr) + { + Exception ex = GetExceptionForHR(hr, useGlobalErrorState: true, out bool restoredExceptionFromGlobalState); + if (restoredExceptionFromGlobalState) + { + ExceptionDispatchInfo.Capture(ex).Throw(); + } + else { - getRestrictedErrorInfo = Platform.GetProcAddress(winRTErrorModule); - setRestrictedErrorInfo = Platform.GetProcAddress(winRTErrorModule); + throw ex; } } - } - - public static void ThrowExceptionForHR(int hr) - { - if (hr < 0) - { - Throw(hr); - } - - static void Throw(int hr) - { - Exception ex = GetExceptionForHR(hr, useGlobalErrorState: true, out bool restoredExceptionFromGlobalState); - if (restoredExceptionFromGlobalState) - { - ExceptionDispatchInfo.Capture(ex).Throw(); - } - else - { - throw ex; - } - } } public static Exception GetExceptionForHR(int hr) => hr >= 0 ? null : GetExceptionForHR(hr, false, out _); diff --git a/src/WinRT.Runtime/Interop/IReferenceTracker.cs b/src/WinRT.Runtime/Interop/IReferenceTracker.cs index e49504a63..845fdc3cf 100644 --- a/src/WinRT.Runtime/Interop/IReferenceTracker.cs +++ b/src/WinRT.Runtime/Interop/IReferenceTracker.cs @@ -4,6 +4,7 @@ namespace WinRT.Interop { [Guid("11D3B13A-180E-4789-A8BE-7712882893E6")] + [StructLayout(LayoutKind.Sequential)] internal unsafe struct IReferenceTrackerVftbl { public global::WinRT.Interop.IUnknownVftbl IUnknownVftbl; diff --git a/src/cswinrt/strings/WinRT.cs b/src/cswinrt/strings/WinRT.cs index b1cdf2354..cca4f6e49 100644 --- a/src/cswinrt/strings/WinRT.cs +++ b/src/cswinrt/strings/WinRT.cs @@ -2,12 +2,11 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Reflection; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; -using System.Linq.Expressions; using System.Diagnostics; using WinRT.Interop; -using System.Runtime.CompilerServices; #pragma warning disable 0169 // The field 'xxx' is never used #pragma warning disable 0649 // Field 'xxx' is never assigned to, and will always have its default value @@ -21,12 +20,6 @@ public static void DynamicInvokeAbi(this System.Delegate del, object[] invoke_pa { Marshal.ThrowExceptionForHR((int)del.DynamicInvoke(invoke_params)); } - - public static T AsDelegate(this MulticastDelegate del) - { - return Marshal.GetDelegateForFunctionPointer( - Marshal.GetFunctionPointerForDelegate(del)); - } } internal class Platform @@ -44,16 +37,16 @@ internal class Platform [return: MarshalAs(UnmanagedType.Bool)] internal static extern bool FreeLibrary(IntPtr moduleHandle); - [DllImport("kernel32.dll", SetLastError = true, BestFitMapping = false)] - internal static extern IntPtr GetProcAddress(IntPtr moduleHandle, [MarshalAs(UnmanagedType.LPStr)] string functionName); - internal static T GetProcAddress(IntPtr moduleHandle) + [DllImport("kernel32.dll", EntryPoint = "GetProcAddress", SetLastError = true, BestFitMapping = false)] + internal static unsafe extern void* TryGetProcAddress(IntPtr moduleHandle, [MarshalAs(UnmanagedType.LPStr)] string functionName); + internal static unsafe void* GetProcAddress(IntPtr moduleHandle, string functionName) { - IntPtr functionPtr = Platform.GetProcAddress(moduleHandle, typeof(T).Name); - if (functionPtr == IntPtr.Zero) + void* functionPtr = Platform.TryGetProcAddress(moduleHandle, functionName); + if (functionPtr == null) { - Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error()); + Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error(), new IntPtr(-1)); } - return Marshal.GetDelegateForFunctionPointer(functionPtr); + return functionPtr; } [DllImport("kernel32.dll", SetLastError = true)] @@ -92,20 +85,12 @@ internal struct VftblPtr public IntPtr Vftbl; } - internal class DllModule + internal unsafe class DllModule { - [UnmanagedFunctionPointer(CallingConvention.StdCall)] - public unsafe delegate int DllGetActivationFactory( - IntPtr activatableClassId, - out IntPtr activationFactory); - - [UnmanagedFunctionPointer(CallingConvention.StdCall)] - public unsafe delegate int DllCanUnloadNow(); - readonly string _fileName; readonly IntPtr _moduleHandle; - readonly DllGetActivationFactory _GetActivationFactory; - readonly DllCanUnloadNow _CanUnloadNow; // TODO: Eventually periodically call + readonly delegate* unmanaged[Stdcall] _GetActivationFactory; + readonly delegate* unmanaged[Stdcall] _CanUnloadNow; // TODO: Eventually periodically call static readonly string _currentModuleDirectory = System.IO.Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location); @@ -128,7 +113,7 @@ public static bool TryLoad(string fileName, out DllModule module) } } - static bool TryCreate(string fileName, out DllModule module) + static unsafe bool TryCreate(string fileName, out DllModule module) { // Explicitly look for module in the same directory as this one, and // use altered search path to ensure any dependencies in the same directory are found. @@ -145,8 +130,8 @@ static bool TryCreate(string fileName, out DllModule module) return false; } - var getActivationFactory = Platform.GetProcAddress(moduleHandle, nameof(DllGetActivationFactory)); - if (getActivationFactory == IntPtr.Zero) + var getActivationFactory = Platform.TryGetProcAddress(moduleHandle, "DllGetActivationFactory"); + if (getActivationFactory == null) { module = null; return false; @@ -155,20 +140,20 @@ static bool TryCreate(string fileName, out DllModule module) module = new DllModule( fileName, moduleHandle, - Marshal.GetDelegateForFunctionPointer(getActivationFactory)); + getActivationFactory); return true; } - DllModule(string fileName, IntPtr moduleHandle, DllGetActivationFactory getActivationFactory) + DllModule(string fileName, IntPtr moduleHandle, void* getActivationFactory) { _fileName = fileName; _moduleHandle = moduleHandle; - _GetActivationFactory = getActivationFactory; + _GetActivationFactory = (delegate* unmanaged[Stdcall])getActivationFactory; - var canUnloadNow = Platform.GetProcAddress(_moduleHandle, nameof(DllCanUnloadNow)); - if (canUnloadNow != IntPtr.Zero) + var canUnloadNow = Platform.TryGetProcAddress(_moduleHandle, "DllCanUnloadNow"); + if (canUnloadNow != null) { - _CanUnloadNow = Marshal.GetDelegateForFunctionPointer(canUnloadNow); + _CanUnloadNow = (delegate* unmanaged[Stdcall])canUnloadNow; } } @@ -178,7 +163,7 @@ public unsafe (ObjectReference obj, int hr) GetActivati var hstrRuntimeClassId = MarshalString.CreateMarshaler(runtimeClassId); try { - int hr = _GetActivationFactory(MarshalString.GetAbi(hstrRuntimeClassId), out instancePtr); + int hr = _GetActivationFactory(MarshalString.GetAbi(hstrRuntimeClassId), &instancePtr); return (hr == 0 ? ObjectReference.Attach(ref instancePtr) : null, hr); } finally