Skip to content

Commit

Permalink
Convert activation and error info delegates to function pointers (#1018)
Browse files Browse the repository at this point in the history
* Convert activation and error info delegates to function pointers

Avoids overhead of creating the marshaled delegates

* CR feedback

Co-authored-by: Manodasan Wignarajah <mawign@microsoft.com>
  • Loading branch information
jkotas and manodasanW authored Oct 11, 2021
1 parent a32c457 commit 4035ab8
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 76 deletions.
76 changes: 36 additions & 40 deletions src/WinRT.Runtime/ExceptionHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace WinRT
{
public static class ExceptionHelpers
public static unsafe class ExceptionHelpers
{
private const int COR_E_OBJECTDISPOSED = unchecked((int)0x80131622);
private const int RO_E_CLOSED = unchecked((int)0x80000013);
Expand All @@ -25,58 +25,54 @@ public static class ExceptionHelpers
[DllImport("oleaut32.dll")]
private static extern int SetErrorInfo(uint dwReserved, IntPtr perrinfo);

internal delegate int GetRestrictedErrorInfo(out IntPtr ppRestrictedErrorInfo);
private static GetRestrictedErrorInfo getRestrictedErrorInfo;
private static delegate* unmanaged[Stdcall]<out IntPtr, int> getRestrictedErrorInfo;
private static delegate* unmanaged[Stdcall]<IntPtr, int> setRestrictedErrorInfo;
private static delegate* unmanaged[Stdcall]<int, IntPtr, IntPtr, int> roOriginateLanguageException;
private static delegate* unmanaged[Stdcall]<IntPtr, int> roReportUnhandledError;

internal delegate int SetRestrictedErrorInfo(IntPtr pRestrictedErrorInfo);
private static SetRestrictedErrorInfo setRestrictedErrorInfo;
private static readonly bool initialized = Initialize();

internal delegate int RoOriginateLanguageException(int error, IntPtr message, IntPtr langaugeException);
private static RoOriginateLanguageException roOriginateLanguageException;

internal delegate int RoReportUnhandledError(IntPtr pRestrictedErrorInfo);
private static RoReportUnhandledError roReportUnhandledError;

static ExceptionHelpers()
private static bool Initialize()
{
IntPtr winRTErrorModule = Platform.LoadLibraryExW("api-ms-win-core-winrt-error-l1-1-1.dll", IntPtr.Zero, (uint)DllImportSearchPath.System32);
if (winRTErrorModule != IntPtr.Zero)
{
getRestrictedErrorInfo = Platform.GetProcAddress<GetRestrictedErrorInfo>(winRTErrorModule);
setRestrictedErrorInfo = Platform.GetProcAddress<SetRestrictedErrorInfo>(winRTErrorModule);
roOriginateLanguageException = Platform.GetProcAddress<RoOriginateLanguageException>(winRTErrorModule);
roReportUnhandledError = Platform.GetProcAddress<RoReportUnhandledError>(winRTErrorModule);
roOriginateLanguageException = (delegate* unmanaged[Stdcall]<int, IntPtr, IntPtr, int>)Platform.GetProcAddress(winRTErrorModule, "RoOriginateLanguageException");
roReportUnhandledError = (delegate* unmanaged[Stdcall]<IntPtr, int>)Platform.GetProcAddress(winRTErrorModule, "RoReportUnhandledError");
}
else
{
winRTErrorModule = Platform.LoadLibraryExW("api-ms-win-core-winrt-error-l1-1-0.dll", IntPtr.Zero, (uint)DllImportSearchPath.System32);
if (winRTErrorModule != IntPtr.Zero)
}

if (winRTErrorModule != IntPtr.Zero)
{
getRestrictedErrorInfo = (delegate* unmanaged[Stdcall]<out IntPtr, int>)Platform.GetProcAddress(winRTErrorModule, "GetRestrictedErrorInfo");
setRestrictedErrorInfo = (delegate* unmanaged[Stdcall]<IntPtr, int>)Platform.GetProcAddress(winRTErrorModule, "SetRestrictedErrorInfo");
}

return true;
}

public static void ThrowExceptionForHR(int hr)
{
if (hr < 0)
{
Throw(hr);
}

static void Throw(int hr)
{
Exception ex = GetExceptionForHR(hr, useGlobalErrorState: true, out bool restoredExceptionFromGlobalState);
if (restoredExceptionFromGlobalState)
{
ExceptionDispatchInfo.Capture(ex).Throw();
}
else
{
getRestrictedErrorInfo = Platform.GetProcAddress<GetRestrictedErrorInfo>(winRTErrorModule);
setRestrictedErrorInfo = Platform.GetProcAddress<SetRestrictedErrorInfo>(winRTErrorModule);
throw ex;
}
}
}

public static void ThrowExceptionForHR(int hr)
{
if (hr < 0)
{
Throw(hr);
}

static void Throw(int hr)
{
Exception ex = GetExceptionForHR(hr, useGlobalErrorState: true, out bool restoredExceptionFromGlobalState);
if (restoredExceptionFromGlobalState)
{
ExceptionDispatchInfo.Capture(ex).Throw();
}
else
{
throw ex;
}
}
}

public static Exception GetExceptionForHR(int hr) => hr >= 0 ? null : GetExceptionForHR(hr, false, out _);
Expand Down
1 change: 1 addition & 0 deletions src/WinRT.Runtime/Interop/IReferenceTracker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
namespace WinRT.Interop
{
[Guid("11D3B13A-180E-4789-A8BE-7712882893E6")]
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct IReferenceTrackerVftbl
{
public global::WinRT.Interop.IUnknownVftbl IUnknownVftbl;
Expand Down
57 changes: 21 additions & 36 deletions src/cswinrt/strings/WinRT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Linq.Expressions;
using System.Diagnostics;
using WinRT.Interop;
using System.Runtime.CompilerServices;

#pragma warning disable 0169 // The field 'xxx' is never used
#pragma warning disable 0649 // Field 'xxx' is never assigned to, and will always have its default value
Expand All @@ -21,12 +20,6 @@ public static void DynamicInvokeAbi(this System.Delegate del, object[] invoke_pa
{
Marshal.ThrowExceptionForHR((int)del.DynamicInvoke(invoke_params));
}

public static T AsDelegate<T>(this MulticastDelegate del)
{
return Marshal.GetDelegateForFunctionPointer<T>(
Marshal.GetFunctionPointerForDelegate(del));
}
}

internal class Platform
Expand All @@ -44,16 +37,16 @@ internal class Platform
[return: MarshalAs(UnmanagedType.Bool)]
internal static extern bool FreeLibrary(IntPtr moduleHandle);

[DllImport("kernel32.dll", SetLastError = true, BestFitMapping = false)]
internal static extern IntPtr GetProcAddress(IntPtr moduleHandle, [MarshalAs(UnmanagedType.LPStr)] string functionName);
internal static T GetProcAddress<T>(IntPtr moduleHandle)
[DllImport("kernel32.dll", EntryPoint = "GetProcAddress", SetLastError = true, BestFitMapping = false)]
internal static unsafe extern void* TryGetProcAddress(IntPtr moduleHandle, [MarshalAs(UnmanagedType.LPStr)] string functionName);
internal static unsafe void* GetProcAddress(IntPtr moduleHandle, string functionName)
{
IntPtr functionPtr = Platform.GetProcAddress(moduleHandle, typeof(T).Name);
if (functionPtr == IntPtr.Zero)
void* functionPtr = Platform.TryGetProcAddress(moduleHandle, functionName);
if (functionPtr == null)
{
Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error());
Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error(), new IntPtr(-1));
}
return Marshal.GetDelegateForFunctionPointer<T>(functionPtr);
return functionPtr;
}

[DllImport("kernel32.dll", SetLastError = true)]
Expand Down Expand Up @@ -92,20 +85,12 @@ internal struct VftblPtr
public IntPtr Vftbl;
}

internal class DllModule
internal unsafe class DllModule
{
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public unsafe delegate int DllGetActivationFactory(
IntPtr activatableClassId,
out IntPtr activationFactory);

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
public unsafe delegate int DllCanUnloadNow();

readonly string _fileName;
readonly IntPtr _moduleHandle;
readonly DllGetActivationFactory _GetActivationFactory;
readonly DllCanUnloadNow _CanUnloadNow; // TODO: Eventually periodically call
readonly delegate* unmanaged[Stdcall]<IntPtr, IntPtr*, int> _GetActivationFactory;
readonly delegate* unmanaged[Stdcall]<int> _CanUnloadNow; // TODO: Eventually periodically call

static readonly string _currentModuleDirectory = System.IO.Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location);

Expand All @@ -128,7 +113,7 @@ public static bool TryLoad(string fileName, out DllModule module)
}
}

static bool TryCreate(string fileName, out DllModule module)
static unsafe bool TryCreate(string fileName, out DllModule module)
{
// Explicitly look for module in the same directory as this one, and
// use altered search path to ensure any dependencies in the same directory are found.
Expand All @@ -145,8 +130,8 @@ static bool TryCreate(string fileName, out DllModule module)
return false;
}

var getActivationFactory = Platform.GetProcAddress(moduleHandle, nameof(DllGetActivationFactory));
if (getActivationFactory == IntPtr.Zero)
var getActivationFactory = Platform.TryGetProcAddress(moduleHandle, "DllGetActivationFactory");
if (getActivationFactory == null)
{
module = null;
return false;
Expand All @@ -155,20 +140,20 @@ static bool TryCreate(string fileName, out DllModule module)
module = new DllModule(
fileName,
moduleHandle,
Marshal.GetDelegateForFunctionPointer<DllGetActivationFactory>(getActivationFactory));
getActivationFactory);
return true;
}

DllModule(string fileName, IntPtr moduleHandle, DllGetActivationFactory getActivationFactory)
DllModule(string fileName, IntPtr moduleHandle, void* getActivationFactory)
{
_fileName = fileName;
_moduleHandle = moduleHandle;
_GetActivationFactory = getActivationFactory;
_GetActivationFactory = (delegate* unmanaged[Stdcall]<IntPtr, IntPtr*, int>)getActivationFactory;

var canUnloadNow = Platform.GetProcAddress(_moduleHandle, nameof(DllCanUnloadNow));
if (canUnloadNow != IntPtr.Zero)
var canUnloadNow = Platform.TryGetProcAddress(_moduleHandle, "DllCanUnloadNow");
if (canUnloadNow != null)
{
_CanUnloadNow = Marshal.GetDelegateForFunctionPointer<DllCanUnloadNow>(canUnloadNow);
_CanUnloadNow = (delegate* unmanaged[Stdcall]<int>)canUnloadNow;
}
}

Expand All @@ -178,7 +163,7 @@ public unsafe (ObjectReference<IActivationFactoryVftbl> obj, int hr) GetActivati
var hstrRuntimeClassId = MarshalString.CreateMarshaler(runtimeClassId);
try
{
int hr = _GetActivationFactory(MarshalString.GetAbi(hstrRuntimeClassId), out instancePtr);
int hr = _GetActivationFactory(MarshalString.GetAbi(hstrRuntimeClassId), &instancePtr);
return (hr == 0 ? ObjectReference<IActivationFactoryVftbl>.Attach(ref instancePtr) : null, hr);
}
finally
Expand Down

0 comments on commit 4035ab8

Please sign in to comment.