Skip to content

Commit

Permalink
[L0] delete platforms on teardown
Browse files Browse the repository at this point in the history
  • Loading branch information
pbalcer committed Mar 7, 2024
1 parent 8499b57 commit 58a0513
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 26 deletions.
70 changes: 51 additions & 19 deletions source/adapters/level_zero/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,33 @@
//===----------------------------------------------------------------------===//

#include "adapter.hpp"
#include "ur_api.h"
#include "ur_level_zero.hpp"

static std::atomic<ur_adapter_handle_t_ *> GlobalAdapter = nullptr;

ur_adapter_handle_t getAdapter(bool initIfNull) {
ur_adapter_handle_t local = GlobalAdapter.load(std::memory_order_acquire);
if (local == nullptr && initIfNull) {
auto newAdapter = new ur_adapter_handle_t_();
if (!GlobalAdapter.compare_exchange_strong(local, newAdapter, std::memory_order_acq_rel, std::memory_order_acquire)) {
delete newAdapter;
}
return getAdapter(initIfNull);
}

return local;
}

void teardownAdapter() {
ur_adapter_handle_t local = GlobalAdapter.load(std::memory_order_acquire);
if (local != nullptr) {
if (GlobalAdapter.compare_exchange_strong(local, nullptr, std::memory_order_acq_rel, std::memory_order_acquire)) {
delete local;
}
}
}

ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
uint32_t ZeDriverCount = 0;
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
Expand All @@ -37,8 +62,7 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
ur_result_t adapterStateInit() { return UR_RESULT_SUCCESS; }

ur_adapter_handle_t_::ur_adapter_handle_t_() {

Adapter.PlatformCache.Compute = [](Result<PlatformVec> &result) {
PlatformCache.Compute = [](Result<PlatformVec> &result) {
static std::once_flag ZeCallCountInitialized;
try {
std::call_once(ZeCallCountInitialized, []() {
Expand All @@ -51,8 +75,10 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
return;
}

auto Adapter = getAdapter();

// initialize level zero only once.
if (Adapter.ZeResult == std::nullopt) {
if (Adapter->ZeResult == std::nullopt) {
// Setting these environment variables before running zeInit will enable
// the validation layer in the Level Zero loader.
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
Expand All @@ -71,20 +97,20 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
// We must only initialize the driver once, even if urPlatformGet() is
// called multiple times. Declaring the return value as "static" ensures
// it's only called once.
Adapter.ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
Adapter->ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
}
assert(Adapter.ZeResult !=
assert(Adapter->ZeResult !=
std::nullopt); // verify that level-zero is initialized
PlatformVec platforms;

// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
if (*Adapter.ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
if (*Adapter->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
result = std::move(platforms);
return;
}
if (*Adapter.ZeResult != ZE_RESULT_SUCCESS) {
if (*Adapter->ZeResult != ZE_RESULT_SUCCESS) {
urPrint("zeInit: Level Zero initialization failure\n");
result = ze2urResult(*Adapter.ZeResult);
result = ze2urResult(*Adapter->ZeResult);
return;
}

Expand All @@ -97,8 +123,6 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
};
}

ur_adapter_handle_t_ Adapter{};

ur_result_t adapterStateTeardown() {
bool LeakFound = false;

Expand Down Expand Up @@ -185,6 +209,8 @@ ur_result_t adapterStateTeardown() {
if (LeakFound)
return UR_RESULT_ERROR_INVALID_MEM_OBJECT;

teardownAdapter();

return UR_RESULT_SUCCESS;
}

Expand All @@ -203,11 +229,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
///< adapters available.
) {
if (NumEntries > 0 && Adapters) {
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
if (Adapter.RefCount++ == 0) {
auto Adapter = getAdapter();
std::lock_guard<std::mutex> Lock{Adapter->Mutex};
if (Adapter->RefCount++ == 0) {
adapterStateInit();
}
*Adapters = &Adapter;
*Adapters = Adapter;
}

if (NumAdapters) {
Expand All @@ -218,17 +245,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
if (--Adapter.RefCount == 0) {
return adapterStateTeardown();
auto Adapter = getAdapter(false);
if (Adapter) {
if (--Adapter->RefCount == 0) {
return adapterStateTeardown();
}
}

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
Adapter.RefCount++;
auto Adapter = getAdapter();
std::lock_guard<std::mutex> Lock{Adapter->Mutex};
Adapter->RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -253,11 +283,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
size_t *PropSizeRet) {
UrReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);

auto Adapter = getAdapter();

switch (PropName) {
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_ADAPTER_BACKEND_LEVEL_ZERO);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(Adapter.RefCount.load());
return ReturnValue(Adapter->RefCount.load());
default:
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ struct ur_adapter_handle_t_ {
ZeCache<Result<PlatformVec>> PlatformCache;
};

extern ur_adapter_handle_t_ Adapter;
ur_adapter_handle_t getAdapter(bool initIfNull = true);
6 changes: 4 additions & 2 deletions source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1441,8 +1441,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
// "NativeHandle" must already be in the cache. If it is not, this must not be
// a valid Level Zero device.

auto Adapter = getAdapter();

ur_device_handle_t Dev = nullptr;
if (const auto *platforms = Adapter.PlatformCache->get_value()) {
if (const auto *platforms = Adapter->PlatformCache->get_value()) {
for (const auto &p : *platforms) {
Dev = p->getDeviceFromNativeHandle(ZeDevice);
if (Dev) {
Expand All @@ -1453,7 +1455,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
}
}
} else {
return Adapter.PlatformCache->get_error();
return Adapter->PlatformCache->get_error();
}

if (Dev == nullptr)
Expand Down
8 changes: 5 additions & 3 deletions source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
uint32_t *NumPlatforms ///< [out][optional] returns the total number of
///< platforms available.
) {
auto Adapter = getAdapter();

// Platform handles are cached for reuse. This is to ensure consistent
// handle pointers across invocations and to improve retrieval performance.
if (const auto *cached_platforms = Adapter.PlatformCache->get_value();
if (const auto *cached_platforms = Adapter->PlatformCache->get_value();
cached_platforms) {
uint32_t nplatforms = (uint32_t)cached_platforms->size();
if (NumPlatforms) {
Expand All @@ -41,7 +43,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
}
}
} else {
return Adapter.PlatformCache->get_error();
return Adapter->PlatformCache->get_error();
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -133,7 +135,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle(
auto ZeDriver = ur_cast<ze_driver_handle_t>(NativePlatform);

uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &Adapter;
ur_adapter_handle_t AdapterHandle = getAdapter();
UR_CALL(urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms));

if (NumPlatforms) {
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
// Maybe this is not completely correct.
uint32_t NumEntries = 1;
ur_platform_handle_t Platform{};
ur_adapter_handle_t AdapterHandle = &Adapter;
ur_adapter_handle_t AdapterHandle = getAdapter();
UR_CALL(urPlatformGet(&AdapterHandle, 1, NumEntries, &Platform, nullptr));

ur_device_handle_t UrDevice = Device;
Expand Down

0 comments on commit 58a0513

Please sign in to comment.