Skip to content

Commit

Permalink
Refactor UR function ptrs
Browse files Browse the repository at this point in the history
  • Loading branch information
callumfare committed Sep 3, 2024
1 parent 1dc8b92 commit 7287d30
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 19 deletions.
23 changes: 20 additions & 3 deletions sycl/include/sycl/detail/ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,31 @@ enum class UrApiKind {
#undef _UR_API
};

struct UrFuncPtrMapT {
#define _UR_API(api) decltype(&::api) pfn_##api = nullptr;
#include <ur_api_funcs.def>
#undef _UR_API
};

template <UrApiKind UrApiOffset> struct UrFuncInfo {};

#ifdef _WIN32
void *GetWinProcAddress(void *module, const char *funcName);
inline void PopulateUrFuncPtrTable(UrFuncPtrMapT *funcs, void *module) {
#define _UR_API(api) \
funcs->pfn_##api = (decltype(&::api))GetWinProcAddress(module, #api);
#include <ur_api_funcs.def>
#undef _UR_API
}

#define _UR_API(api) \
template <> struct UrFuncInfo<UrApiKind::api> { \
using FuncPtrT = decltype(&::api); \
inline const char *getFuncName() { return #api; } \
inline FuncPtrT getFuncPtr(void *module) { \
inline FuncPtrT getFuncPtr(const UrFuncPtrMapT *funcs) { \
return funcs->pfn_##api; \
} \
inline FuncPtrT getFuncPtrFromModule(void *module) { \
return (FuncPtrT)GetWinProcAddress(module, #api); \
} \
};
Expand All @@ -72,7 +88,8 @@ void *GetWinProcAddress(void *module, const char *funcName);
template <> struct UrFuncInfo<UrApiKind::api> { \
using FuncPtrT = decltype(&::api); \
inline const char *getFuncName() { return #api; } \
constexpr inline FuncPtrT getFuncPtr(void *) { return &api; } \
constexpr inline FuncPtrT getFuncPtr(const void *) { return &api; } \
constexpr inline FuncPtrT getFuncPtrFromModule(void *) { return &api; } \
};
#include <ur_api_funcs.def>
#undef _UR_API
Expand Down Expand Up @@ -106,7 +123,7 @@ int unloadOsLibrary(void *Library);
// library, implementation is OS dependent.
void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName);

void *loadURLoaderLibrary();
void *getURLoaderLibrary();

// Performs UR one-time initialization.
std::vector<PluginPtr> &
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/global_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ void GlobalHandler::unloadPlugins() {

UrFuncInfo<UrApiKind::urLoaderTearDown> loaderTearDownInfo;
auto loaderTearDown =
loaderTearDownInfo.getFuncPtr(ur::loadURLoaderLibrary());
loaderTearDownInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
loaderTearDown();
// urLoaderTearDown();

Expand Down
6 changes: 4 additions & 2 deletions sycl/source/detail/plugin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class plugin {
MPluginMutex(std::make_shared<std::mutex>()) {

#ifdef _WIN32
UrLoaderHandle = ur::loadURLoaderLibrary();
UrLoaderHandle = ur::getURLoaderLibrary();
PopulateUrFuncPtrTable(&UrFuncPtrs, UrLoaderHandle);
#endif
}

Expand Down Expand Up @@ -123,7 +124,7 @@ class plugin {
ur_result_t R = UR_RESULT_SUCCESS;
if (!adapterReleased) {
detail::UrFuncInfo<UrApiOffset> UrApiInfo;
auto F = UrApiInfo.getFuncPtr(UrLoaderHandle);
auto F = UrApiInfo.getFuncPtr(&UrFuncPtrs);
R = F(Args...);
}
return R;
Expand Down Expand Up @@ -220,6 +221,7 @@ class plugin {
// index of this vector corresponds to the index in UrPlatforms vector.
std::vector<int> LastDeviceIds;
void *UrLoaderHandle = nullptr;
UrFuncPtrMapT UrFuncPtrs;
}; // class plugin

using PluginPtr = std::shared_ptr<plugin>;
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/posix_ur.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) {
return dlsym(Library, FunctionName.c_str());
}

void *loadURLoaderLibrary() { return nullptr; }
void *getURLoaderLibrary() { return nullptr; }

} // namespace detail::ur
} // namespace _V1
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(

UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
auto programRelease =
programReleaseInfo.getFuncPtr(ur::loadURLoaderLibrary());
programReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
ProgramPtr ProgramManaged(NativePrg, programRelease);

// Link a fallback implementation of device libraries if they are not
Expand Down Expand Up @@ -2555,7 +2555,7 @@ device_image_plain ProgramManager::build(const device_image_plain &DeviceImage,

UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
auto programRelease =
programReleaseInfo.getFuncPtr(ur::loadURLoaderLibrary());
programReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
ProgramPtr ProgramManaged(NativePrg, programRelease);

// Link a fallback implementation of device libraries if they are not
Expand Down Expand Up @@ -2769,7 +2769,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
auto &Plugin = DeviceImpl->getPlugin();
UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
auto programRelease =
programReleaseInfo.getFuncPtr(ur::loadURLoaderLibrary());
programReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
ProgramPtr ProgramManaged(Program, programRelease);

std::string CompileOpts;
Expand Down
16 changes: 8 additions & 8 deletions sycl/source/detail/ur.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,25 @@ static void initializePlugins(std::vector<PluginPtr> &Plugins,

UrFuncInfo<UrApiKind::urLoaderConfigCreate> loaderConfigCreateInfo;
auto loaderConfigCreate =
loaderConfigCreateInfo.getFuncPtr(ur::loadURLoaderLibrary());
loaderConfigCreateInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
UrFuncInfo<UrApiKind::urLoaderConfigEnableLayer> loaderConfigEnableLayerInfo;
auto loaderConfigEnableLayer =
loaderConfigEnableLayerInfo.getFuncPtr(ur::loadURLoaderLibrary());
loaderConfigEnableLayerInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
UrFuncInfo<UrApiKind::urLoaderConfigRelease> loaderConfigReleaseInfo;
auto loaderConfigRelease =
loaderConfigReleaseInfo.getFuncPtr(ur::loadURLoaderLibrary());
loaderConfigReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
UrFuncInfo<UrApiKind::urLoaderConfigSetCodeLocationCallback>
loaderConfigSetCodeLocationCallbackInfo;
auto loaderConfigSetCodeLocationCallback =
loaderConfigSetCodeLocationCallbackInfo.getFuncPtr(
ur::loadURLoaderLibrary());
loaderConfigSetCodeLocationCallbackInfo.getFuncPtrFromModule(
ur::getURLoaderLibrary());
UrFuncInfo<UrApiKind::urLoaderInit> loaderInitInfo;
auto loaderInit = loaderInitInfo.getFuncPtr(ur::loadURLoaderLibrary());
auto loaderInit = loaderInitInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
UrFuncInfo<UrApiKind::urAdapterGet> adapterGet_Info;
auto adapterGet = adapterGet_Info.getFuncPtr(ur::loadURLoaderLibrary());
auto adapterGet = adapterGet_Info.getFuncPtrFromModule(ur::getURLoaderLibrary());
UrFuncInfo<UrApiKind::urAdapterGetInfo> adapterGetInfoInfo;
auto adapterGetInfo =
adapterGetInfoInfo.getFuncPtr(ur::loadURLoaderLibrary());
adapterGetInfoInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());

bool OwnLoaderConfig = false;
// If we weren't provided with a custom config handle create our own.
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/windows_ur.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ static std::filesystem::path getCurrentDSODirPath() {
return std::filesystem::path(Path);
}

void *loadURLoaderLibrary() {
void *getURLoaderLibrary() {
const std::filesystem::path LibSYCLDir = getCurrentDSODirPath();
return getPreloadedPlugin(LibSYCLDir / std::string("ur_loader.dll"));
}
Expand Down

0 comments on commit 7287d30

Please sign in to comment.