Skip to content

Commit

Permalink
Don't use exceptions for flow control when loading dlls. (#933)
Browse files Browse the repository at this point in the history
* Don't use exceptions for flow control when loading dlls.

* Fix.

* Add method back.
  • Loading branch information
jlaanstra authored Jul 28, 2021
1 parent 75ffe54 commit c956f4c
Showing 1 changed file with 41 additions and 27 deletions.
68 changes: 41 additions & 27 deletions src/cswinrt/strings/WinRT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ internal class Platform

[DllImport("kernel32.dll", SetLastError = true, BestFitMapping = false)]
internal static extern IntPtr GetProcAddress(IntPtr moduleHandle, [MarshalAs(UnmanagedType.LPStr)] string functionName);

internal static T GetProcAddress<T>(IntPtr moduleHandle)
{
IntPtr functionPtr = Platform.GetProcAddress(moduleHandle, typeof(T).Name);
Expand Down Expand Up @@ -112,46 +111,61 @@ public unsafe delegate int DllGetActivationFactory(

static Dictionary<string, DllModule> _cache = new System.Collections.Generic.Dictionary<string, DllModule>();

public static DllModule Load(string fileName)
public static bool TryLoad(string fileName, out DllModule module)
{
lock (_cache)
{
DllModule module;
if (!_cache.TryGetValue(fileName, out module))
if (_cache.TryGetValue(fileName, out module))
{
return true;
}
else if (TryCreate(fileName, out module))
{
module = new DllModule(fileName);
_cache[fileName] = module;
return true;
}
return module;
return false;
}
}

DllModule(string fileName)
{
_fileName = fileName;

static bool TryCreate(string fileName, out DllModule module)
{
// Explicitly look for module in the same directory as this one, and
// use altered search path to ensure any dependencies in the same directory are found.
_moduleHandle = Platform.LoadLibraryExW(System.IO.Path.Combine(_currentModuleDirectory, fileName), IntPtr.Zero, /* LOAD_WITH_ALTERED_SEARCH_PATH */ 8);
var moduleHandle = Platform.LoadLibraryExW(System.IO.Path.Combine(_currentModuleDirectory, fileName), IntPtr.Zero, /* LOAD_WITH_ALTERED_SEARCH_PATH */ 8);
#if !NETSTANDARD2_0 && !NETCOREAPP2_0
if (_moduleHandle == IntPtr.Zero)
if (moduleHandle == IntPtr.Zero)
{
try
{
// Allow runtime to find module in RID-specific relative subfolder
_moduleHandle = NativeLibrary.Load(fileName, Assembly.GetExecutingAssembly(), null);
}
catch (Exception) { }
NativeLibrary.TryLoad(fileName, Assembly.GetExecutingAssembly(), null, out moduleHandle);
}
#endif
if (_moduleHandle == IntPtr.Zero)
if (moduleHandle == IntPtr.Zero)
{
Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error());
module = null;
return false;
}

_GetActivationFactory = Platform.GetProcAddress<DllGetActivationFactory>(_moduleHandle);
var getActivationFactory = Platform.GetProcAddress(moduleHandle, nameof(DllGetActivationFactory));
if (getActivationFactory == IntPtr.Zero)
{
module = null;
return false;
}

module = new DllModule(
fileName,
moduleHandle,
Marshal.GetDelegateForFunctionPointer<DllGetActivationFactory>(getActivationFactory));
return true;
}

DllModule(string fileName, IntPtr moduleHandle, DllGetActivationFactory getActivationFactory)
{
_fileName = fileName;
_moduleHandle = moduleHandle;
_GetActivationFactory = getActivationFactory;

var canUnloadNow = Platform.GetProcAddress(_moduleHandle, "DllCanUnloadNow");
var canUnloadNow = Platform.GetProcAddress(_moduleHandle, nameof(DllCanUnloadNow));
if (canUnloadNow != IntPtr.Zero)
{
_CanUnloadNow = Marshal.GetDelegateForFunctionPointer<DllCanUnloadNow>(canUnloadNow);
Expand Down Expand Up @@ -264,12 +278,12 @@ public BaseActivationFactory(string typeNamespace, string typeFullName)
var moduleName = typeNamespace;
while (true)
{
try
{
(_IActivationFactory, _) = DllModule.Load(moduleName + ".dll").GetActivationFactory(typeFullName);
if (_IActivationFactory != null) { return; }
DllModule module = null;
if (DllModule.TryLoad(moduleName + ".dll", out module))
{
(_IActivationFactory, _) = module.GetActivationFactory(typeFullName);
if (_IActivationFactory != null) { return; }
}
catch (Exception) { }

var lastSegment = moduleName.LastIndexOf(".");
if (lastSegment <= 0)
Expand Down

0 comments on commit c956f4c

Please sign in to comment.