Skip to content

Commit

Permalink
refactor loader lifetime management
Browse files Browse the repository at this point in the history
This patch implements an atomic singleton class for managing
the lifecycle of the context objects inside of the loader.
This class ensures that the contexts always exist and lets the
loader manually destroy them on user request (during teardown).

Thanks to this change, the loader no longer relies on the order
of library constructors and destructors. It also gets us 90%
towards allowing the loader to be statically linked with the
application.
  • Loading branch information
pbalcer committed Jul 5, 2024
1 parent 731376d commit 905a066
Show file tree
Hide file tree
Showing 40 changed files with 3,053 additions and 2,923 deletions.
12 changes: 6 additions & 6 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace ur_loader
size_t adapterIndex = 0;
if( nullptr != ${obj['params'][1]['name']} && ${obj['params'][0]['name']} !=0)
{
for( auto& platform : context->platforms )
for( auto& platform : getContext()->platforms )
{
if(platform.initStatus != ${X}_RESULT_SUCCESS)
continue;
Expand All @@ -81,7 +81,7 @@ namespace ur_loader

if( ${obj['params'][2]['name']} != nullptr )
{
*${obj['params'][2]['name']} = static_cast<uint32_t>(context->platforms.size());
*${obj['params'][2]['name']} = static_cast<uint32_t>(getContext()->platforms.size());
}

%elif re.match(r"\w+PlatformGet$", th.make_func_name(n, tags, obj)):
Expand Down Expand Up @@ -360,13 +360,13 @@ ${tbl['export']['name']}(
if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
if( ur_loader::context->version < version )
if( ur_loader::getContext()->version < version )
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
${x}_result_t result = ${X}_RESULT_SUCCESS;
// Load the device-platform DDI tables
for( auto& platform : ur_loader::context->platforms )
for( auto& platform : ur_loader::getContext()->platforms )
{
if(platform.initStatus != ${X}_RESULT_SUCCESS)
continue;
Expand All @@ -379,7 +379,7 @@ ${tbl['export']['name']}(
if( ${X}_RESULT_SUCCESS == result )
{
if( ur_loader::context->platforms.size() != 1 || ur_loader::context->forceIntercept )
if( ur_loader::getContext()->platforms.size() != 1 || ur_loader::getContext()->forceIntercept )
{
// return pointers to loader's DDIs
%for obj in tbl['functions']:
Expand All @@ -397,7 +397,7 @@ ${tbl['export']['name']}(
else
{
// return pointers directly to platform's DDIs
*pDdiTable = ur_loader::context->platforms.front().dditable.${n}.${tbl['name']};
*pDdiTable = ur_loader::getContext()->platforms.front().dditable.${n}.${tbl['name']};
}
}
Expand Down
22 changes: 2 additions & 20 deletions scripts/templates/libapi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,10 @@ ${th.make_func_name(n, tags, obj)}(
%endfor
)
try {
%if re.match("Init", obj['name']):
<%
param_checks=th.make_param_checks(n, tags, obj, meta=meta).items()
%>
%for key, values in param_checks:
%for val in values:
if( ${val} )
return ${key};

%endfor
%endfor

static ${x}_result_t result = ${X}_RESULT_SUCCESS;
std::call_once(${x}_lib::context->initOnce, [device_flags, hLoaderConfig]() {
result = ${x}_lib::context->Init(device_flags, hLoaderConfig);
});

return result;
%elif th.obj_traits.is_loader_only(obj):
%if th.obj_traits.is_loader_only(obj):
return ur_lib::${th.make_func_name(n, tags, obj)}(${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
%else:
auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNINITIALIZED;

Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/libddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace ${x}_lib
///////////////////////////////////////////////////////////////////////////////


__${x}dlllocal ${x}_result_t context_t::${n}LoaderInit()
__${x}dlllocal ${x}_result_t context_t::ddiInit()
{
${x}_result_t result = ${X}_RESULT_SUCCESS;

Expand Down
18 changes: 9 additions & 9 deletions scripts/templates/trcddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,23 @@ namespace ur_tracing_layer
%endfor
)
{
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNSUPPORTED_FEATURE;

${th.make_pfncb_param_type(n, tags, obj)} params = { &${",&".join(th.make_param_lines(n, tags, obj, format=["name"]))} };
uint64_t instance = context.notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);
uint64_t instance = getContext()->notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);

context.logger.info("---> ${th.make_func_name(n, tags, obj)}");
getContext()->logger.info("---> ${th.make_func_name(n, tags, obj)}");

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );

context.notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);
getContext()->notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);

std::ostringstream args_str;
ur::extras::printFunctionParams(args_str, ${th.make_func_etor(n, tags, obj)}, &params);
context.logger.info("({}) -> {};\n", args_str.str(), result);
getContext()->logger.info("({}) -> {};\n", args_str.str(), result);

return result;
}
Expand All @@ -79,13 +79,13 @@ namespace ur_tracing_layer
%endfor
)
{
auto& dditable = ur_tracing_layer::context.${n}DdiTable.${tbl['name']};
auto& dditable = ur_tracing_layer::getContext()->${n}DdiTable.${tbl['name']};

if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;

if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_tracing_layer::context.version) > UR_MINOR_VERSION(version))
if (UR_MAJOR_VERSION(ur_tracing_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_tracing_layer::getContext()->version) > UR_MINOR_VERSION(version))
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;

${x}_result_t result = ${X}_RESULT_SUCCESS;
Expand Down Expand Up @@ -122,7 +122,7 @@ namespace ur_tracing_layer
// program launch and the call to `urLoaderInit`
logger = logger::create_logger("tracing", true, true);

ur_tracing_layer::context.codelocData = codelocData;
ur_tracing_layer::getContext()->codelocData = codelocData;

%for tbl in th.get_pfntables(specs, meta, n, tags):
if( ${X}_RESULT_SUCCESS == result )
Expand Down
20 changes: 10 additions & 10 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ namespace ur_validation_layer
%endfor
)
{
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) {
return ${X}_RESULT_ERROR_UNINITIALIZED;
}

if( context.enableParameterValidation )
if( getContext()->enableParameterValidation )
{
%for key, values in sorted_param_checks:
%for val in values:
Expand All @@ -80,7 +80,7 @@ namespace ur_validation_layer
is_related_create_get_retain_release_func = any(func_name in funcs for funcs in tp_input_handle_funcs.values())
%>
%if tp_input_handle_funcs and not is_related_create_get_retain_release_func:
if (context.enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
if (getContext()->enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
refCountContext.logInvalidReference(${tp['name']});
}
%endif
Expand All @@ -94,24 +94,24 @@ namespace ur_validation_layer
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['create']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.createRefCount(*${tp['name']});
}
%elif func_name in tp_handle_funcs['get']:
if( context.enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
{
for (uint32_t i = ${th.param_traits.range_start(tp)}; i < ${th.param_traits.range_end(tp)}; i++) {
refCountContext.createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
}
}
%elif func_name in tp_handle_funcs['retain']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%elif func_name in tp_handle_funcs['release']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
Expand Down Expand Up @@ -141,13 +141,13 @@ namespace ur_validation_layer
%endfor
)
{
auto& dditable = ur_validation_layer::context.${n}DdiTable.${tbl['name']};
auto& dditable = ur_validation_layer::getContext()->${n}DdiTable.${tbl['name']};

if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;

if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_validation_layer::context.version) > UR_MINOR_VERSION(version))
if (UR_MAJOR_VERSION(ur_validation_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_validation_layer::getContext()->version) > UR_MINOR_VERSION(version))
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;

${x}_result_t result = ${X}_RESULT_SUCCESS;
Expand Down
34 changes: 34 additions & 0 deletions source/common/ur_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <ur_api.h>

#include <atomic>
#include <iostream>
#include <map>
#include <optional>
Expand Down Expand Up @@ -343,4 +344,37 @@ splitMetadataName(const std::string &metadataName) {
return std::make_pair(metadataName.substr(0, splitPos),
metadataName.substr(splitPos, metadataName.length()));
}

template <typename T> class AtomicSingleton {
private:
static std::atomic<T *> instance;

public:
static T *get() {
T *current = instance.load(std::memory_order_acquire);

if (current == nullptr) {
T *newContext = new T();
if (!instance.compare_exchange_strong(current, newContext,
std::memory_order_acq_rel)) {
delete newContext;
} else {
current = newContext;
}
}
return current;
}

static void destroy() {
T *current = instance.load(std::memory_order_acquire);

if (current != nullptr &&
instance.compare_exchange_strong(current, nullptr,
std::memory_order_acq_rel)) {
delete current;
}
}
};
template <typename T> std::atomic<T *> AtomicSingleton<T>::instance(nullptr);

#endif /* UR_UTIL_H */
4 changes: 0 additions & 4 deletions source/loader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,10 @@ if(WIN32)
target_sources(ur_loader
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/windows/adapter_search.cpp
${CMAKE_CURRENT_SOURCE_DIR}/windows/lib_init.cpp
${CMAKE_CURRENT_SOURCE_DIR}/windows/loader_init.cpp
)
else()
target_sources(ur_loader
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/linux/adapter_search.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linux/lib_init.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linux/loader_init.cpp
)
endif()
2 changes: 1 addition & 1 deletion source/loader/layers/sanitizer/asan_allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
namespace ur_sanitizer_layer {

void AllocInfo::print() {
context.logger.info(
getContext()->logger.info(
"AllocInfo(Alloc=[{}-{}), User=[{}-{}), AllocSize={}, Type={})",
(void *)AllocBegin, (void *)(AllocBegin + AllocSize), (void *)UserBegin,
(void *)(UserEnd), AllocSize, ToString(Type));
Expand Down
24 changes: 13 additions & 11 deletions source/loader/layers/sanitizer/asan_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,16 @@ ur_result_t EnqueueMemCopyRectHelper(
// loop call 2D memory copy function to implement it.
for (size_t i = 0; i < Region.depth; i++) {
ur_event_handle_t NewEvent{};
UR_CALL(context.urDdiTable.Enqueue.pfnUSMMemcpy2D(
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D(
Queue, Blocking, DstOrigin + (i * DstSlicePitch), DstRowPitch,
SrcOrigin + (i * SrcSlicePitch), SrcRowPitch, Region.width,
Region.height, NumEventsInWaitList, EventWaitList, &NewEvent));

Events.push_back(NewEvent);
}

UR_CALL(context.urDdiTable.Enqueue.pfnEventsWait(Queue, Events.size(),
Events.data(), Event));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
Queue, Events.size(), Events.data(), Event));

return UR_RESULT_SUCCESS;
}
Expand All @@ -80,23 +80,24 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
ur_usm_desc_t USMDesc{};
USMDesc.align = getAlignment();
ur_usm_pool_handle_t Pool{};
ur_result_t URes = context.interceptor->allocateMemory(
ur_result_t URes = getContext()->interceptor->allocateMemory(
Context, Device, &USMDesc, Pool, Size, AllocType::MEM_BUFFER,
ur_cast<void **>(&Allocation));
if (URes != UR_RESULT_SUCCESS) {
context.logger.error(
getContext()->logger.error(
"Failed to allocate {} bytes memory for buffer {}", Size, this);
return URes;
}

if (HostPtr) {
ManagedQueue Queue(Context, Device);
URes = context.urDdiTable.Enqueue.pfnUSMMemcpy(
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostPtr, Size, 0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
context.logger.error("Failed to copy {} bytes data from host "
"pointer {} to buffer {}",
Size, HostPtr, this);
getContext()->logger.error(
"Failed to copy {} bytes data from host "
"pointer {} to buffer {}",
Size, HostPtr, this);
return URes;
}
}
Expand All @@ -109,9 +110,10 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

ur_result_t MemBuffer::free() {
for (const auto &[_, Ptr] : Allocations) {
ur_result_t URes = context.interceptor->releaseMemory(Context, Ptr);
ur_result_t URes =
getContext()->interceptor->releaseMemory(Context, Ptr);
if (URes != UR_RESULT_SUCCESS) {
context.logger.error("Failed to free buffer handle {}", Ptr);
getContext()->logger.error("Failed to free buffer handle {}", Ptr);
return URes;
}
}
Expand Down
Loading

0 comments on commit 905a066

Please sign in to comment.