Skip to content

Commit

Permalink
refactor loader lifetime management
Browse files Browse the repository at this point in the history
This patch implements an atomic singleton class for managing
the lifecycle of the context objects inside of the loader.
This class ensures that the contexts always exist and lets the
loader manually destroy them on user request (during teardown).

Thanks to this change, the loader no longer relies on the order
of library constructors and destructors. It also gets us 90%
towards allowing the loader to be statically linked with the
application.
  • Loading branch information
pbalcer committed Jul 11, 2024
1 parent 642e343 commit 4d1b56e
Show file tree
Hide file tree
Showing 50 changed files with 4,379 additions and 3,594 deletions.
32 changes: 11 additions & 21 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,6 @@ from templates import helper as th

namespace ur_loader
{
///////////////////////////////////////////////////////////////////////////////
%for obj in th.get_adapter_handles(specs):
%if 'class' in obj:
<%
_handle_t = th.subt(n, tags, obj['name'])
_factory_t = re.sub(r"(\w+)_handle_t", r"\1_factory_t", _handle_t)
_factory = re.sub(r"(\w+)_handle_t", r"\1_factory", _handle_t)
%>${th.append_ws(_factory_t, 35)} ${_factory};
%endif
%endfor

%for obj in th.get_adapter_functions(specs):
///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
Expand All @@ -51,6 +40,7 @@ namespace ur_loader
add_local = False
%>${th.get_initial_null_set(obj)}

[[maybe_unused]] auto context = getContext();
%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):

size_t adapterIndex = 0;
Expand All @@ -63,7 +53,7 @@ namespace ur_loader
platform.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( 1, &${obj['params'][1]['name']}[adapterIndex], nullptr );
try
{
${obj['params'][1]['name']}[adapterIndex] = reinterpret_cast<${n}_adapter_handle_t>(${n}_adapter_factory.getInstance(
${obj['params'][1]['name']}[adapterIndex] = reinterpret_cast<${n}_adapter_handle_t>(context->factories.${n}_adapter_factory.getInstance(
${obj['params'][1]['name']}[adapterIndex], &platform.dditable
));
}
Expand Down Expand Up @@ -114,7 +104,7 @@ namespace ur_loader
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
uint32_t platform_index = total_platform_handle_count + i;
${obj['params'][3]['name']}[ platform_index ] = reinterpret_cast<${n}_platform_handle_t>(
${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
context->factories.${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
}
}
catch( std::bad_alloc& )
Expand Down Expand Up @@ -294,7 +284,7 @@ namespace ur_loader
for (size_t i = 0; i < nelements; ++i) {
if (handles[i] != nullptr) {
handles[i] = reinterpret_cast<${etor['type']}>(
${etor['factory']}.getInstance( handles[i], dditable ) );
context->factories.${etor['factory']}.getInstance( handles[i], dditable ) );
}
}
} break;
Expand All @@ -306,16 +296,16 @@ namespace ur_loader
// convert platform handles to loader handles
for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i )
${item['name']}[ i ] = reinterpret_cast<${item['type']}>(
${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
context->factories.${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
%else:
// convert platform handle to loader handle
%if item['optional'] or th.always_wrap_outputs(obj):
if( nullptr != ${item['name']} )
*${item['name']} = reinterpret_cast<${item['type']}>(
${item['factory']}.getInstance( *${item['name']}, dditable ) );
context->factories.${item['factory']}.getInstance( *${item['name']}, dditable ) );
%else:
*${item['name']} = reinterpret_cast<${item['type']}>(
${item['factory']}.getInstance( *${item['name']}, dditable ) );
context->factories.${item['factory']}.getInstance( *${item['name']}, dditable ) );
%endif
%endif
}
Expand Down Expand Up @@ -360,13 +350,13 @@ ${tbl['export']['name']}(
if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
if( ur_loader::context->version < version )
if( ur_loader::getContext()->version < version )
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
${x}_result_t result = ${X}_RESULT_SUCCESS;
// Load the device-platform DDI tables
for( auto& platform : ur_loader::context->platforms )
for( auto& platform : ur_loader::getContext()->platforms )
{
if(platform.initStatus != ${X}_RESULT_SUCCESS)
continue;
Expand All @@ -379,7 +369,7 @@ ${tbl['export']['name']}(
if( ${X}_RESULT_SUCCESS == result )
{
if( ur_loader::context->platforms.size() != 1 || ur_loader::context->forceIntercept )
if( ur_loader::getContext()->platforms.size() != 1 || ur_loader::getContext()->forceIntercept )
{
// return pointers to loader's DDIs
%for obj in tbl['functions']:
Expand All @@ -397,7 +387,7 @@ ${tbl['export']['name']}(
else
{
// return pointers directly to platform's DDIs
*pDdiTable = ur_loader::context->platforms.front().dditable.${n}.${tbl['name']};
*pDdiTable = ur_loader::getContext()->platforms.front().dditable.${n}.${tbl['name']};
}
}
Expand Down
12 changes: 12 additions & 0 deletions scripts/templates/ldrddi.hpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,29 @@ from templates import helper as th
namespace ur_loader
{
///////////////////////////////////////////////////////////////////////////////
<%
factories = []
%>
%for obj in th.get_adapter_handles(specs):
%if 'class' in obj:
<%
_handle_t = th.subt(n, tags, obj['name'])
_object_t = re.sub(r"(\w+)_handle_t", r"\1_object_t", _handle_t)
_factory_t = re.sub(r"(\w+)_handle_t", r"\1_factory_t", _handle_t)
_factory = re.sub(r"(\w+)_handle_t", r"\1_factory", _handle_t)
factories.append((_factory_t, _factory))
%>using ${th.append_ws(_object_t, 35)} = object_t < ${_handle_t} >;
using ${th.append_ws(_factory_t, 35)} = singleton_factory_t < ${_object_t}, ${_handle_t} >;

%endif
%endfor

struct handle_factories {
%for (f_t, f) in factories:
${f_t} ${f};
%endfor
};

}

#endif /* UR_LOADER_LDRDDI_H */
22 changes: 2 additions & 20 deletions scripts/templates/libapi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,10 @@ ${th.make_func_name(n, tags, obj)}(
%endfor
)
try {
%if re.match("Init", obj['name']):
<%
param_checks=th.make_param_checks(n, tags, obj, meta=meta).items()
%>
%for key, values in param_checks:
%for val in values:
if( ${val} )
return ${key};

%endfor
%endfor

static ${x}_result_t result = ${X}_RESULT_SUCCESS;
std::call_once(${x}_lib::context->initOnce, [device_flags, hLoaderConfig]() {
result = ${x}_lib::context->Init(device_flags, hLoaderConfig);
});

return result;
%elif th.obj_traits.is_loader_only(obj):
%if th.obj_traits.is_loader_only(obj):
return ur_lib::${th.make_func_name(n, tags, obj)}(${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
%else:
${th.get_initial_null_set(obj)}auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
${th.get_initial_null_set(obj)}auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNINITIALIZED;

Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/libddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace ${x}_lib
///////////////////////////////////////////////////////////////////////////////


__${x}dlllocal ${x}_result_t context_t::${n}LoaderInit()
__${x}dlllocal ${x}_result_t context_t::ddiInit()
{
${x}_result_t result = ${X}_RESULT_SUCCESS;

Expand Down
18 changes: 9 additions & 9 deletions scripts/templates/trcddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,23 @@ namespace ur_tracing_layer
%endfor
)
{${th.get_initial_null_set(obj)}
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNSUPPORTED_FEATURE;

${th.make_pfncb_param_type(n, tags, obj)} params = { &${",&".join(th.make_param_lines(n, tags, obj, format=["name"]))} };
uint64_t instance = context.notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);
uint64_t instance = getContext()->notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);

context.logger.info("---> ${th.make_func_name(n, tags, obj)}");
getContext()->logger.info("---> ${th.make_func_name(n, tags, obj)}");

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );

context.notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);
getContext()->notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);

std::ostringstream args_str;
ur::extras::printFunctionParams(args_str, ${th.make_func_etor(n, tags, obj)}, &params);
context.logger.info("({}) -> {};\n", args_str.str(), result);
getContext()->logger.info("({}) -> {};\n", args_str.str(), result);

return result;
}
Expand All @@ -79,13 +79,13 @@ namespace ur_tracing_layer
%endfor
)
{
auto& dditable = ur_tracing_layer::context.${n}DdiTable.${tbl['name']};
auto& dditable = ur_tracing_layer::getContext()->${n}DdiTable.${tbl['name']};

if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;

if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_tracing_layer::context.version) > UR_MINOR_VERSION(version))
if (UR_MAJOR_VERSION(ur_tracing_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_tracing_layer::getContext()->version) > UR_MINOR_VERSION(version))
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;

${x}_result_t result = ${X}_RESULT_SUCCESS;
Expand Down Expand Up @@ -122,7 +122,7 @@ namespace ur_tracing_layer
// program launch and the call to `urLoaderInit`
logger = logger::create_logger("tracing", true, true);

ur_tracing_layer::context.codelocData = codelocData;
ur_tracing_layer::getContext()->codelocData = codelocData;

%for tbl in th.get_pfntables(specs, meta, n, tags):
if( ${X}_RESULT_SUCCESS == result )
Expand Down
Empty file.
34 changes: 17 additions & 17 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ namespace ur_validation_layer
%endfor
)
{${th.get_initial_null_set(obj)}
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};

if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) {
return ${X}_RESULT_ERROR_UNINITIALIZED;
}

if( context.enableParameterValidation )
if( getContext()->enableParameterValidation )
{
%for key, values in sorted_param_checks:
%for val in values:
Expand All @@ -80,8 +80,8 @@ namespace ur_validation_layer
is_related_create_get_retain_release_func = any(func_name in funcs for funcs in tp_input_handle_funcs.values())
%>
%if tp_input_handle_funcs and not is_related_create_get_retain_release_func:
if (context.enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
refCountContext.logInvalidReference(${tp['name']});
if (getContext()->enableLifetimeValidation && !getContext()->refCountContext->isReferenceValid(${tp['name']})) {
getContext()->refCountContext->logInvalidReference(${tp['name']});
}
%endif
%endfor
Expand All @@ -94,26 +94,26 @@ namespace ur_validation_layer
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['create']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.createRefCount(*${tp['name']});
getContext()->refCountContext->createRefCount(*${tp['name']});
}
%elif func_name in tp_handle_funcs['get']:
if( context.enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
{
for (uint32_t i = ${th.param_traits.range_start(tp)}; i < ${th.param_traits.range_end(tp)}; i++) {
refCountContext.createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
getContext()->refCountContext->createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
}
}
%elif func_name in tp_handle_funcs['retain']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%elif func_name in tp_handle_funcs['release']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor
Expand Down Expand Up @@ -141,13 +141,13 @@ namespace ur_validation_layer
%endfor
)
{
auto& dditable = ur_validation_layer::context.${n}DdiTable.${tbl['name']};
auto& dditable = ur_validation_layer::getContext()->${n}DdiTable.${tbl['name']};

if( nullptr == pDdiTable )
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;

if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_validation_layer::context.version) > UR_MINOR_VERSION(version))
if (UR_MAJOR_VERSION(ur_validation_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
UR_MINOR_VERSION(ur_validation_layer::getContext()->version) > UR_MINOR_VERSION(version))
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;

${x}_result_t result = ${X}_RESULT_SUCCESS;
Expand Down Expand Up @@ -212,8 +212,8 @@ namespace ur_validation_layer
${x}_result_t result = ${X}_RESULT_SUCCESS;

if (enableLeakChecking) {
refCountContext.logInvalidReferences();
refCountContext.clear();
getContext()->refCountContext->logInvalidReferences();
getContext()->refCountContext->clear();
}
return result;
}
Expand Down
5 changes: 5 additions & 0 deletions source/common/ur_singleton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
std::lock_guard<std::mutex> lk(mut);
map.erase(getKey(key));
}

void clear() {
std::lock_guard<std::mutex> lk(mut);
map.clear();
}
};

#endif /* UR_SINGLETON_H */
Loading

0 comments on commit 4d1b56e

Please sign in to comment.