Skip to content

Commit

Permalink
ensure UR will clear context on unloading
Browse files Browse the repository at this point in the history
  • Loading branch information
omarahmed1111 committed Aug 30, 2024
1 parent 70e4cdc commit df47657
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 1 deletion.
5 changes: 5 additions & 0 deletions source/adapters/level_zero/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ if(UR_BUILD_ADAPTER_L0)
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_linux.cpp
)
else()
target_sources(ur_adapter_level_zero
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_windows.cpp
)
endif()

# TODO: fix level_zero adapter conversion warnings
Expand Down
38 changes: 38 additions & 0 deletions source/adapters/level_zero/adapter_lib_init_windows.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===--------- adapter_lib_init_linux.cpp - Level Zero Adapter ------------===//
//
// Copyright (C) 2023 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "adapter.hpp"
#include "ur_level_zero.hpp"

#include <windows.h>

BOOL WINAPI DllMain(HINSTANCE hinstDLL, // handle to DLL module
DWORD fdwReason, // reason for calling function
LPVOID lpReserved) // reserved
{
switch (fdwReason) {
case DLL_PROCESS_ATTACH:
break;
case DLL_PROCESS_DETACH: {
const auto *platforms = GlobalAdapter->PlatformCache->get_value();
for (const auto &p : *platforms) {
while (!p->Contexts.empty()) {
UR_CALL(urContextRelease(p->Contexts.front()));
}
}
break;
}
case DLL_THREAD_ATTACH:
break;
case DLL_THREAD_DETACH:
break;
}
return TRUE;
}
25 changes: 24 additions & 1 deletion source/adapters/level_zero/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(

Context->initialize();
*RetContext = reinterpret_cast<ur_context_handle_t>(Context);
#ifdef _WIN32
std::scoped_lock<ur_shared_mutex> Lock(Platform->ContextsMutex);
auto It = std::find(Platform->Contexts.begin(), Platform->Contexts.end(),
*RetContext);
if (It == Platform->Contexts.end()) {
Platform->Contexts.push_back(*RetContext);
}
#else
if (IndirectAccessTrackingEnabled) {
std::scoped_lock<ur_shared_mutex> Lock(Platform->ContextsMutex);
Platform->Contexts.push_back(*RetContext);
}
#endif
} catch (const std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (...) {
Expand Down Expand Up @@ -355,13 +364,21 @@ ur_result_t ContextReleaseHelper(ur_context_handle_t Context) {
if (!Context->RefCount.decrementAndTest())
return UR_RESULT_SUCCESS;

if (IndirectAccessTrackingEnabled) {
auto DeleteFromContextsCache = [&]() {
ur_platform_handle_t Plt = Context->getPlatform();
auto &Contexts = Plt->Contexts;
auto It = std::find(Contexts.begin(), Contexts.end(), Context);
if (It != Contexts.end())
Contexts.erase(It);
};

#ifdef _WIN32
DeleteFromContextsCache();
#else
if (IndirectAccessTrackingEnabled) {
DeleteFromContextsCache();
}
#endif
ze_context_handle_t DestroyZeContext =
Context->OwnNativeHandle ? Context->ZeContext : nullptr;

Expand Down Expand Up @@ -451,6 +468,12 @@ ur_result_t ur_context_handle_t_::finalize() {
}
}
}

for (auto &kernel : KernelsCache) {
UR_CALL(urKernelRelease(kernel));
}
KernelsCache.clear();

return UR_RESULT_SUCCESS;
}

Expand Down
2 changes: 2 additions & 0 deletions source/adapters/level_zero/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ struct ur_context_handle_t_ : _ur_object {
std::vector<std::unordered_map<ur_device_handle_t, size_t>>
EventCachesDeviceMap{4};

std::vector<ur_kernel_handle_t> KernelsCache;

// Initialize the PI context.
ur_result_t initialize();

Expand Down
18 changes: 18 additions & 0 deletions source/adapters/level_zero/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
try {
ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program);
*RetKernel = reinterpret_cast<ur_kernel_handle_t>(UrKernel);

#ifdef _WIN32
auto &Context = Program->Context;
auto It = std::find(Context->KernelsCache.begin(),
Context->KernelsCache.end(), *RetKernel);
if (It == Context->KernelsCache.end()) {
Context->KernelsCache.push_back(*RetKernel);
}
#endif

} catch (const std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (...) {
Expand Down Expand Up @@ -902,6 +912,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(
return ze2urResult(ZeResult);
}
}

#ifdef _WIN32
auto &KernelsCache = Kernel->Program->Context->KernelsCache;
auto It = std::find(KernelsCache.begin(), KernelsCache.end(), Kernel);
if (It != KernelsCache.end())
KernelsCache.erase(It);
#endif

Kernel->ZeKernelMap.clear();
if (IndirectAccessTrackingEnabled) {
UR_CALL(urContextRelease(KernelProgram->Context));
Expand Down

0 comments on commit df47657

Please sign in to comment.