Skip to content

Commit

Permalink
refactor loader lifetime management
Browse files Browse the repository at this point in the history
This patch implements an atomic singleton class for managing
the lifecycle of the context objects inside of the loader.
This class ensures that the contexts always exist and lets the
loader manually destroy them on user request (during teardown).

Thanks to this change, the loader no longer relies on the order
of library constructors and destructors. It also gets us 90%
towards allowing the loader to be statically linked with the
application.
  • Loading branch information
pbalcer committed Jul 8, 2024
1 parent 731376d commit 281f027
Show file tree
Hide file tree
Showing 42 changed files with 3,235 additions and 2,928 deletions.
12 changes: 6 additions & 6 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace ur_loader
size_t adapterIndex = 0;
if( nullptr != ${obj['params'][1]['name']} && ${obj['params'][0]['name']} !=0)
{
for( auto& platform : context->platforms )
for( auto& platform : getContext()->platforms )
{
if(platform.initStatus != ${X}_RESULT_SUCCESS)
continue;
Expand All @@ -81,7 +81,7 @@ namespace ur_loader

if( ${obj['params'][2]['name']} != nullptr )
{
*${obj['params'][2]['name']} = static_cast<uint32_t>(context->platforms.size());
*${obj['params'][2]['name']} = static_cast<uint32_t>(getContext()->platforms.size());
}

%elif re.match(r"\w+PlatformGet$", th.make_func_name(n, tags, obj)):
Expand Down Expand Up @@ -360,13 +360,13 @@ ${tbl['export']['name']}(
if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
if( ur_loader::context->version < version )
if( ur_loader::getContext()->version < version )
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
${x}_result_t result = ${X}_RESULT_SUCCESS;
// Load the device-platform DDI tables
for( auto& platform : ur_loader::context->platforms )
for( auto& platform : ur_loader::getContext()->platforms )
{
if(platform.initStatus != ${X}_RESULT_SUCCESS)
continue;
Expand All @@ -379,7 +379,7 @@ ${tbl['export']['name']}(
if( ${X}_RESULT_SUCCESS == result )
{
if( ur_loader::context->platforms.size() != 1 || ur_loader::context->forceIntercept )
if( ur_loader::getContext()->platforms.size() != 1 || ur_loader::getContext()->forceIntercept )
{
// return pointers to loader's DDIs
%for obj in tbl['functions']:
Expand All @@ -397,7 +397,7 @@ ${tbl['export']['name']}(
else
{
// return pointers directly to platform's DDIs
*pDdiTable = ur_loader::context->platforms.front().dditable.${n}.${tbl['name']};
*pDdiTable = ur_loader::getContext()->platforms.front().dditable.${n}.${tbl['name']};
}
}
Expand Down
22 changes: 2 additions & 20 deletions scripts/templates/libapi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,10 @@ ${th.make_func_name(n, tags, obj)}(
%endfor
)
try {
%if re.match("Init", obj['name']):
<%
param_checks=th.make_param_checks(n, tags, obj, meta=meta).items()
%>
%for key, values in param_checks:
%for val in values:
if( ${val} )
return ${key};

%endfor
%endfor

static ${x}_result_t result = ${X}_RESULT_SUCCESS;
std::call_once(${x}_lib::context->initOnce, [device_flags, hLoaderConfig]() {
result = ${x}_lib::context->Init(device_flags, hLoaderConfig);
});

return result;
%elif th.obj_traits.is_loader_only(obj):
%if th.obj_traits.is_loader_only(obj):
return ur_lib::${th.make_func_name(n, tags, obj)}(${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
%else:
auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNINITIALIZED;

Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/libddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace ${x}_lib
///////////////////////////////////////////////////////////////////////////////


__${x}dlllocal ${x}_result_t context_t::${n}LoaderInit()
__${x}dlllocal ${x}_result_t context_t::ddiInit()
{
${x}_result_t result = ${X}_RESULT_SUCCESS;

Expand Down
18 changes: 9 additions & 9 deletions scripts/templates/trcddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,23 @@ namespace ur_tracing_layer
%endfor
)
{
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNSUPPORTED_FEATURE;

${th.make_pfncb_param_type(n, tags, obj)} params = { &${",&".join(th.make_param_lines(n, tags, obj, format=["name"]))} };
uint64_t instance = context.notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);
uint64_t instance = getContext()->notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);

context.logger.info("---> ${th.make_func_name(n, tags, obj)}");
getContext()->logger.info("---> ${th.make_func_name(n, tags, obj)}");

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );

context.notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);
getContext()->notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);

std::ostringstream args_str;
ur::extras::printFunctionParams(args_str, ${th.make_func_etor(n, tags, obj)}, &params);
context.logger.info("({}) -> {};\n", args_str.str(), result);
getContext()->logger.info("({}) -> {};\n", args_str.str(), result);

return result;
}
Expand All @@ -79,13 +79,13 @@ namespace ur_tracing_layer
%endfor
)
{
auto& dditable = ur_tracing_layer::context.${n}DdiTable.${tbl['name']};
auto& dditable = ur_tracing_layer::getContext()->${n}DdiTable.${tbl['name']};

if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;

if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_tracing_layer::context.version) > UR_MINOR_VERSION(version))
if (UR_MAJOR_VERSION(ur_tracing_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_tracing_layer::getContext()->version) > UR_MINOR_VERSION(version))
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;

${x}_result_t result = ${X}_RESULT_SUCCESS;
Expand Down Expand Up @@ -122,7 +122,7 @@ namespace ur_tracing_layer
// program launch and the call to `urLoaderInit`
logger = logger::create_logger("tracing", true, true);

ur_tracing_layer::context.codelocData = codelocData;
ur_tracing_layer::getContext()->codelocData = codelocData;

%for tbl in th.get_pfntables(specs, meta, n, tags):
if( ${X}_RESULT_SUCCESS == result )
Expand Down
20 changes: 10 additions & 10 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ namespace ur_validation_layer
%endfor
)
{
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) {
return ${X}_RESULT_ERROR_UNINITIALIZED;
}

if( context.enableParameterValidation )
if( getContext()->enableParameterValidation )
{
%for key, values in sorted_param_checks:
%for val in values:
Expand All @@ -80,7 +80,7 @@ namespace ur_validation_layer
is_related_create_get_retain_release_func = any(func_name in funcs for funcs in tp_input_handle_funcs.values())
%>
%if tp_input_handle_funcs and not is_related_create_get_retain_release_func:
if (context.enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
if (getContext()->enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
refCountContext.logInvalidReference(${tp['name']});
}
%endif
Expand All @@ -94,24 +94,24 @@ namespace ur_validation_layer
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['create']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.createRefCount(*${tp['name']});
}
%elif func_name in tp_handle_funcs['get']:
if( context.enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
{
for (uint32_t i = ${th.param_traits.range_start(tp)}; i < ${th.param_traits.range_end(tp)}; i++) {
refCountContext.createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
}
}
%elif func_name in tp_handle_funcs['retain']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%elif func_name in tp_handle_funcs['release']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
Expand Down Expand Up @@ -141,13 +141,13 @@ namespace ur_validation_layer
%endfor
)
{
auto& dditable = ur_validation_layer::context.${n}DdiTable.${tbl['name']};
auto& dditable = ur_validation_layer::getContext()->${n}DdiTable.${tbl['name']};

if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;

if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_validation_layer::context.version) > UR_MINOR_VERSION(version))
if (UR_MAJOR_VERSION(ur_validation_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_validation_layer::getContext()->version) > UR_MINOR_VERSION(version))
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;

${x}_result_t result = ${X}_RESULT_SUCCESS;
Expand Down
106 changes: 106 additions & 0 deletions source/common/ur_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
*/

#include <algorithm>
#include <functional>
#include <memory>
#include <thread>
#ifndef UR_UTIL_H
#define UR_UTIL_H 1

#include <ur_api.h>

#include <atomic>
#include <iostream>
#include <map>
#include <optional>
Expand Down Expand Up @@ -343,4 +347,106 @@ splitMetadataName(const std::string &metadataName) {
return std::make_pair(metadataName.substr(0, splitPos),
metadataName.substr(splitPos, metadataName.length()));
}

// A simple spinlock guarding a single value of type T.
//
// The lock itself must be kept trivially destructible so that it's safe to
// use after its destructor has been called; the wrapper as a whole is
// trivially destructible only when T is.
//
// Both members use default member initializers and the constructor is
// defaulted. `atomic_flag x = ATOMIC_FLAG_INIT;` is the only initialization
// form the standard guarantees for the macro, and with no user-provided
// constructor body the object can be constant-initialized whenever T's
// default constructor is constexpr.
template <typename T> class Spinlock {
  public:
    Spinlock() = default;

    // Spins (yielding the CPU between attempts) until the flag is obtained,
    // then returns a pointer to the guarded value. The caller must call
    // release() once it is done with the value.
    T *acquire() {
        while (lock.test_and_set(std::memory_order_acquire)) {
            std::this_thread::yield();
        }
        return &value;
    }

    // Clears the flag so that another acquire() can succeed.
    void release() { lock.clear(std::memory_order_release); }

  private:
    std::atomic_flag lock = ATOMIC_FLAG_INIT;
    T value{};
};

// A reference counted pointer.
//
// The managed object is created lazily by get() and destroyed when the
// number of release() calls catches up with the number of get() calls, or
// immediately by forceDelete(). The class has no destructor, so it stays
// trivially destructible; an object that is still referenced when the Rc
// goes away is not freed by this class.
//
// Rc performs no synchronization of its own; callers must serialize access.
template <typename T> class Rc {
  public:
    // constexpr so that an Rc with static storage duration can be
    // constant-initialized instead of being (re)set by a dynamic initializer.
    constexpr Rc() noexcept : ptr(nullptr), refcount(0) {}
    Rc(const Rc &) = delete;
    Rc &operator=(const Rc &) = delete;
    Rc(Rc &&) = delete;
    Rc &operator=(Rc &&) = delete;

    // Returns the managed object, default-constructing it on first use, and
    // takes one reference to it. If construction throws, no reference is
    // taken and the state is unchanged.
    T *get() {
        if (ptr == nullptr) {
            ptr = new T();
        }
        refcount++;
        return ptr;
    }

    // Drops one reference. When the last reference is dropped, the object is
    // handed to `deleter`; an empty `deleter` falls back to plain `delete`.
    //
    // Returns 0 on success and -1 when there is no outstanding reference.
    int release(std::function<void(T *)> deleter) {
        if (refcount == 0) {
            return -1;
        }

        if (--refcount == 0) {
            // Detach the object before running user code. If the deleter
            // throws or calls back into get(), this Rc must not still be
            // pointing at an object that is being (or has been) destroyed.
            T *expired = ptr;
            ptr = nullptr;

            if (deleter) {
                deleter(expired);
            } else {
                delete expired;
            }
        }

        return 0;
    }

    // Destroys the managed object (if any) with `delete`, regardless of how
    // many references are outstanding, and resets the count to zero.
    void forceDelete() {
        T *expired = ptr;
        ptr = nullptr;
        refcount = 0;
        delete expired;
    }

  private:
    T *ptr;
    size_t refcount;
};

// AtomicSingleton is for those cases where we want to support creating state
// on first use, global MT-safe reference-counted access, explicit synchronized deletion,
// and, on top of all that, need to gracefully handle situations where destructor order
// causes a library/application to call into the loader after it has been destroyed.
template <typename T> class AtomicSingleton {
private:
static Spinlock<Rc<T>> instance;
// Simply using an std::mutex would have been much simpler, but mutexes might
// get deleted prior to last use of this type.

public:
static T *get() {
auto val = instance.acquire();

auto ptr = val->get();

instance.release();

return ptr;
}

static int release(std::function<void(T *)> deleter) {
auto val = instance.acquire();
int ret = val->release(deleter);
instance.release();

return ret;
}

// When we don't care about the refcount or the refcount is external.
static void forceDelete() {
auto val = instance.acquire();

val->forceDelete();

instance.release();
}
};

template <typename T> Spinlock<Rc<T>> AtomicSingleton<T>::instance;

#endif /* UR_UTIL_H */
4 changes: 0 additions & 4 deletions source/loader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,10 @@ if(WIN32)
target_sources(ur_loader
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/windows/adapter_search.cpp
${CMAKE_CURRENT_SOURCE_DIR}/windows/lib_init.cpp
${CMAKE_CURRENT_SOURCE_DIR}/windows/loader_init.cpp
)
else()
target_sources(ur_loader
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/linux/adapter_search.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linux/lib_init.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linux/loader_init.cpp
)
endif()
2 changes: 1 addition & 1 deletion source/loader/layers/sanitizer/asan_allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
namespace ur_sanitizer_layer {

void AllocInfo::print() {
context.logger.info(
getContext()->logger.info(
"AllocInfo(Alloc=[{}-{}), User=[{}-{}), AllocSize={}, Type={})",
(void *)AllocBegin, (void *)(AllocBegin + AllocSize), (void *)UserBegin,
(void *)(UserEnd), AllocSize, ToString(Type));
Expand Down
Loading

0 comments on commit 281f027

Please sign in to comment.