diff --git a/sycl/include/sycl/detail/os_util.hpp b/sycl/include/sycl/detail/os_util.hpp index 4ab08bb284937..fa55425cb2505 100644 --- a/sycl/include/sycl/detail/os_util.hpp +++ b/sycl/include/sycl/detail/os_util.hpp @@ -15,6 +15,7 @@ #include // for size_t #include // for string #include // for stat +#include #ifdef _WIN32 #define __SYCL_RT_OS_WINDOWS @@ -49,6 +50,9 @@ class __SYCL_EXPORT OSUtil { /// Returns a directory component of a path. static std::string getDirName(const char *Path); + /// Returns an absolute path to a directory where the object was found. + static std::filesystem::path getCurrentDSODirPath(); + #ifdef __SYCL_RT_OS_WINDOWS static constexpr const char *DirSep = "\\"; #else diff --git a/sycl/include/sycl/detail/pi.hpp b/sycl/include/sycl/detail/pi.hpp index c4ddfe9fd1b44..380db95aa42b9 100644 --- a/sycl/include/sycl/detail/pi.hpp +++ b/sycl/include/sycl/detail/pi.hpp @@ -25,6 +25,7 @@ #include // for char_traits, string #include // for false_type, true_type #include // for vector +#include #ifdef XPTI_ENABLE_INSTRUMENTATION // Forward declarations @@ -171,7 +172,7 @@ __SYCL_EXPORT void contextSetExtendedDeleter(const sycl::context &constext, // Function to load a shared library // Implementation is OS dependent -void *loadOsLibrary(const std::string &Library); +void *loadOsLibrary(const std::filesystem::path &Library); // Function to unload a shared library // Implementation is OS dependent (see posix-pi.cpp and windows-pi.cpp) @@ -180,7 +181,7 @@ int unloadOsLibrary(void *Library); // Function to load the shared plugin library // On Windows, this will have been pre-loaded by proxy loader. // Implementation is OS dependent. -void *loadOsPluginLibrary(const std::string &Library); +void *loadOsPluginLibrary(const std::filesystem::path &Library); // Function to unload the shared plugin library // Implementation is OS dependent (see posix-pi.cpp and windows-pi.cpp) diff --git a/sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp b/sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp index fcc5e49c15344..e0d0037cf962b 100644 --- a/sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp +++ b/sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp @@ -23,6 +23,7 @@ // similar approach. #include +#include #ifdef _WIN32 @@ -99,6 +100,25 @@ std::string getCurrentDSODir() { return Path; } +std::filesystem::path getCurrentDSODirPath() { + wchar_t Path[MAX_PATH]; + //Path[0] = '\0'; + //Path[sizeof(Path) - 1] = '\0'; + auto Handle = getOSModuleHandle(reinterpret_cast(&getCurrentDSODir)); + DWORD Ret = GetModuleFileName( + reinterpret_cast(ExeModuleHandle == Handle ? 0 : Handle), + reinterpret_cast(&Path), sizeof(Path)); + assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?"); + assert(Ret > 0 && "GetModuleFileName failed"); + (void)Ret; + + BOOL RetCode = PathRemoveFileSpec(reinterpret_cast(&Path)); + assert(RetCode && "PathRemoveFileSpec failed"); + (void)RetCode; + + return std::filesystem::path(std::wstring(Path)); +} + // these are cribbed from include/sycl/detail/pi.hpp // a new plugin must be added to both places. #ifdef _MSC_VER @@ -121,7 +141,7 @@ std::string getCurrentDSODir() { // ------------------------------------ -using MapT = std::map; +using MapT = std::map; MapT &getDllMap() { static MapT dllMap; @@ -141,46 +161,46 @@ void preloadLibraries() { // UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS); // Exclude current directory from DLL search path - if (!SetDllDirectoryA("")) { + if (!SetDllDirectory(L"")) { assert(false && "Failed to update DLL search path"); } // this path duplicates sycl/detail/pi.cpp:initializePlugins - const std::string LibSYCLDir = getCurrentDSODir() + DirSep; + std::filesystem::path LibSYCLDir = getCurrentDSODirPath(); MapT &dllMap = getDllMap(); + + auto ocl_path = LibSYCLDir / __SYCL_OPENCL_PLUGIN_NAME; + dllMap.emplace(ocl_path, LoadLibrary(ocl_path.wstring().c_str())); - std::string ocl_path = LibSYCLDir + __SYCL_OPENCL_PLUGIN_NAME; - dllMap.emplace(ocl_path, LoadLibraryA(ocl_path.c_str())); - - std::string l0_path = LibSYCLDir + __SYCL_LEVEL_ZERO_PLUGIN_NAME; - dllMap.emplace(l0_path, LoadLibraryA(l0_path.c_str())); + auto l0_path = LibSYCLDir / __SYCL_LEVEL_ZERO_PLUGIN_NAME; + dllMap.emplace(l0_path, LoadLibrary(l0_path.wstring().c_str())); - std::string cuda_path = LibSYCLDir + __SYCL_CUDA_PLUGIN_NAME; - dllMap.emplace(cuda_path, LoadLibraryA(cuda_path.c_str())); + auto cuda_path = LibSYCLDir / __SYCL_CUDA_PLUGIN_NAME; + dllMap.emplace(cuda_path, LoadLibrary(cuda_path.wstring().c_str())); - std::string esimd_path = LibSYCLDir + __SYCL_ESIMD_EMULATOR_PLUGIN_NAME; - dllMap.emplace(esimd_path, LoadLibraryA(esimd_path.c_str())); + auto esimd_path = LibSYCLDir / __SYCL_ESIMD_EMULATOR_PLUGIN_NAME; + dllMap.emplace(esimd_path, LoadLibrary(esimd_path.wstring().c_str())); - std::string hip_path = LibSYCLDir + __SYCL_HIP_PLUGIN_NAME; - dllMap.emplace(hip_path, LoadLibraryA(hip_path.c_str())); + auto hip_path = LibSYCLDir / __SYCL_HIP_PLUGIN_NAME; + dllMap.emplace(hip_path, LoadLibrary(hip_path.wstring().c_str())); - std::string ur_path = LibSYCLDir + __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME; - dllMap.emplace(ur_path, LoadLibraryA(ur_path.c_str())); + auto ur_path = LibSYCLDir / __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME; + dllMap.emplace(ur_path, LoadLibrary(ur_path.wstring().c_str())); - std::string nativecpu_path = LibSYCLDir + __SYCL_NATIVE_CPU_PLUGIN_NAME; - dllMap.emplace(nativecpu_path, LoadLibraryA(nativecpu_path.c_str())); + auto nativecpu_path = LibSYCLDir / __SYCL_NATIVE_CPU_PLUGIN_NAME; + dllMap.emplace(nativecpu_path, LoadLibrary(nativecpu_path.wstring().c_str())); // Restore system error handling. (void)SetErrorMode(SavedMode); - if (!SetDllDirectoryA(nullptr)) { + if (!SetDllDirectory(nullptr)) { assert(false && "Failed to restore DLL search path"); } } /// windows_pi.cpp:loadOsPluginLibrary() calls this to get the DLL loaded /// earlier. -__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) { +__declspec(dllexport) void *getPreloadedPlugin(const std::filesystem::path &PluginPath) { MapT &dllMap = getDllMap(); @@ -188,11 +208,11 @@ __declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) { // which is perfectly valid. if (match == dllMap.end()) { // unit testing? return nullptr (not found) rather than risk asserting below - if (PluginPath.find("unittests") != std::string::npos) + if (PluginPath.string().find("unittests") != std::string::npos) return nullptr; // Otherwise, asking for something we don't know about at all, is an issue. - std::cout << "unknown plugin: " << PluginPath << std::endl; + std::cout << "unknown plugin: " << PluginPath.string() << std::endl; assert(false && "getPreloadedPlugin was given an unknown plugin path."); return nullptr; } diff --git a/sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp b/sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp index c1104a6d26c77..462b8f111f114 100644 --- a/sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp +++ b/sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp @@ -10,6 +10,7 @@ #ifdef _WIN32 #include +#include -__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath); +__declspec(dllexport) void *getPreloadedPlugin(const std::filesystem::path &PluginPath); #endif diff --git a/sycl/source/detail/os_util.cpp b/sycl/source/detail/os_util.cpp index dde34762843f8..9383460ac249a 100644 --- a/sycl/source/detail/os_util.cpp +++ b/sycl/source/detail/os_util.cpp @@ -10,6 +10,7 @@ #include #include +#include #if defined(__SYCL_RT_OS_LINUX) @@ -138,6 +139,11 @@ std::string OSUtil::getDirName(const char *Path) { return Tmp; } +std::filesystem::path OSUtil::getCurrentDSODirPath() { + return std::filesystem::path(OSUtil::getCurrentDSODir()); +} + + #elif defined(__SYCL_RT_OS_WINDOWS) // TODO: Just inline it. using OSModuleHandle = intptr_t; @@ -192,6 +198,26 @@ std::string OSUtil::getDirName(const char *Path) { return Tmp; } +std::filesystem::path OSUtil::getCurrentDSODirPath() { + wchar_t Path[MAX_PATH]; + //Path[0] = '\0'; + //Path[sizeof(Path) - 1] = '\0'; + auto Handle = getOSModuleHandle(reinterpret_cast(&getCurrentDSODir)); + DWORD Ret = GetModuleFileName( + reinterpret_cast(ExeModuleHandle == Handle ? 0 : Handle), + reinterpret_cast(&Path), sizeof(Path)); + assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?"); + assert(Ret > 0 && "GetModuleFileName failed"); + (void)Ret; + + BOOL RetCode = PathRemoveFileSpec(reinterpret_cast(&Path)); + assert(RetCode && "PathRemoveFileSpec failed"); + (void)RetCode; + + return std::filesystem::path(Path); +} + + #elif defined(__SYCL_RT_OS_DARWIN) std::string OSUtil::getCurrentDSODir() { auto CurrentFunc = reinterpret_cast(&getCurrentDSODir); @@ -208,6 +234,10 @@ std::string OSUtil::getCurrentDSODir() { return Path.substr(0, LastSlashPos); } +std::filesystem::path OSUtil::getCurrentDSODirPath() { + return std::filesystem::path(OSUtil::getCurrentDSODir()); +} + #endif // __SYCL_RT_OS size_t OSUtil::getOSMemSize() { diff --git a/sycl/source/detail/pi.cpp b/sycl/source/detail/pi.cpp index 33dfdaf005e41..305ca22c49758 100644 --- a/sycl/source/detail/pi.cpp +++ b/sycl/source/detail/pi.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -385,7 +386,7 @@ std::vector> findPlugins() { // Load the Plugin by calling the OS dependent library loading call. // Return the handle to the Library. -void *loadPlugin(const std::string &PluginPath) { +void *loadPlugin(const std::filesystem::path &PluginPath) { return loadOsPluginLibrary(PluginPath); } @@ -442,15 +443,14 @@ static void initializePlugins(std::vector &Plugins) { std::cerr << "SYCL_PI_TRACE[all]: " << "No Plugins Found." << std::endl; - const std::string LibSYCLDir = - sycl::detail::OSUtil::getCurrentDSODir() + sycl::detail::OSUtil::DirSep; + std::filesystem::path LibSYCLDir = sycl::detail::OSUtil::getCurrentDSODirPath(); for (unsigned int I = 0; I < PluginNames.size(); I++) { std::shared_ptr PluginInformation = std::make_shared( PiPlugin{_PI_H_VERSION_STRING, _PI_H_VERSION_STRING, /*Targets=*/nullptr, /*FunctionPointers=*/{}}); - void *Library = loadPlugin(LibSYCLDir + PluginNames[I].first); + void *Library = loadPlugin(LibSYCLDir / PluginNames[I].first); // loadPlugin(path) if (!Library) { if (trace(PI_TRACE_ALL)) { diff --git a/sycl/source/detail/posix_pi.cpp b/sycl/source/detail/posix_pi.cpp index e72f4d8b0af2f..3c1a0672149e7 100644 --- a/sycl/source/detail/posix_pi.cpp +++ b/sycl/source/detail/posix_pi.cpp @@ -12,25 +12,26 @@ #include #include +#include namespace sycl { inline namespace _V1 { namespace detail::pi { -void *loadOsLibrary(const std::string &LibraryPath) { +void *loadOsLibrary(const std::filesystem::path &LibraryPath) { // TODO: Check if the option RTLD_NOW is correct. Explore using // RTLD_DEEPBIND option when there are multiple plugins. - void *so = dlopen(LibraryPath.c_str(), RTLD_NOW); + void *so = dlopen(LibraryPath.string().c_str(), RTLD_NOW); if (!so && trace(TraceLevel::PI_TRACE_ALL)) { char *Error = dlerror(); - std::cerr << "SYCL_PI_TRACE[-1]: dlopen(" << LibraryPath + std::cerr << "SYCL_PI_TRACE[-1]: dlopen(" << LibraryPath.string() << ") failed with <" << (Error ? Error : "unknown error") << ">" << std::endl; } return so; } -void *loadOsPluginLibrary(const std::string &PluginPath) { +void *loadOsPluginLibrary(const std::filesystem::path &PluginPath) { return loadOsLibrary(PluginPath); } diff --git a/sycl/source/detail/windows_pi.cpp b/sycl/source/detail/windows_pi.cpp index 1b99d73b89657..fae0853b5af3b 100644 --- a/sycl/source/detail/windows_pi.cpp +++ b/sycl/source/detail/windows_pi.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "pi_win_proxy_loader.hpp" @@ -20,7 +21,7 @@ inline namespace _V1 { namespace detail { namespace pi { -void *loadOsLibrary(const std::string &LibraryPath) { +void *loadOsLibrary(const std::filesystem::path &LibraryPath) { // Tells the system to not display the critical-error-handler message box. // Instead, the system sends the error to the calling process. // This is crucial for graceful handling of shared libs that can't be @@ -28,19 +29,19 @@ void *loadOsLibrary(const std::string &LibraryPath) { UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS); // Exclude current directory from DLL search path - if (!SetDllDirectoryA("")) { + if (!SetDllDirectory(L"")) { assert(false && "Failed to update DLL search path"); } - auto Result = (void *)LoadLibraryA(LibraryPath.c_str()); + auto Result = (void *)LoadLibrary(LibraryPath.wstring().c_str()); (void)SetErrorMode(SavedMode); - if (!SetDllDirectoryA(nullptr)) { + if (!SetDllDirectory(nullptr)) { assert(false && "Failed to restore DLL search path"); } return Result; } -void *loadOsPluginLibrary(const std::string &PluginPath) { +void *loadOsPluginLibrary(const std::filesystem::path &PluginPath) { // We fetch the preloaded plugin from the pi_win_proxy_loader. // The proxy_loader handles any required error suppression. auto Result = getPreloadedPlugin(PluginPath);