From df47657f17b9726f61251cbe2256c03f2534501d Mon Sep 17 00:00:00 2001 From: omarahmed1111 Date: Thu, 29 Aug 2024 16:21:15 +0100 Subject: [PATCH] ensure UR will clear context on unloading --- source/adapters/level_zero/CMakeLists.txt | 5 +++ .../level_zero/adapter_lib_init_windows.cpp | 38 +++++++++++++++++++ source/adapters/level_zero/context.cpp | 25 +++++++++++- source/adapters/level_zero/context.hpp | 2 + source/adapters/level_zero/kernel.cpp | 18 +++++++++ 5 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 source/adapters/level_zero/adapter_lib_init_windows.cpp diff --git a/source/adapters/level_zero/CMakeLists.txt b/source/adapters/level_zero/CMakeLists.txt index 05bf05e0a7..923f1ceb62 100644 --- a/source/adapters/level_zero/CMakeLists.txt +++ b/source/adapters/level_zero/CMakeLists.txt @@ -140,6 +140,11 @@ if(UR_BUILD_ADAPTER_L0) PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_linux.cpp ) + else() + target_sources(ur_adapter_level_zero + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_windows.cpp + ) endif() # TODO: fix level_zero adapter conversion warnings diff --git a/source/adapters/level_zero/adapter_lib_init_windows.cpp b/source/adapters/level_zero/adapter_lib_init_windows.cpp new file mode 100644 index 0000000000..9f6150a2dc --- /dev/null +++ b/source/adapters/level_zero/adapter_lib_init_windows.cpp @@ -0,0 +1,38 @@ +//===--------- adapter_lib_init_linux.cpp - Level Zero Adapter ------------===// +// +// Copyright (C) 2023 Intel Corporation +// +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM +// Exceptions. See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "adapter.hpp" +#include "ur_level_zero.hpp" + +#include + +BOOL WINAPI DllMain(HINSTANCE hinstDLL, // handle to DLL module + DWORD fdwReason, // reason for calling function + LPVOID lpReserved) // reserved +{ + switch (fdwReason) { + case DLL_PROCESS_ATTACH: + break; + case DLL_PROCESS_DETACH: { + const auto *platforms = GlobalAdapter->PlatformCache->get_value(); + for (const auto &p : *platforms) { + while (!p->Contexts.empty()) { + UR_CALL(urContextRelease(p->Contexts.front())); + } + } + break; + } + case DLL_THREAD_ATTACH: + break; + case DLL_THREAD_DETACH: + break; + } + return TRUE; +} diff --git a/source/adapters/level_zero/context.cpp b/source/adapters/level_zero/context.cpp index 452189d038..5a0beb9e65 100644 --- a/source/adapters/level_zero/context.cpp +++ b/source/adapters/level_zero/context.cpp @@ -40,10 +40,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate( Context->initialize(); *RetContext = reinterpret_cast(Context); +#ifdef _WIN32 + std::scoped_lock Lock(Platform->ContextsMutex); + auto It = std::find(Platform->Contexts.begin(), Platform->Contexts.end(), + *RetContext); + if (It == Platform->Contexts.end()) { + Platform->Contexts.push_back(*RetContext); + } +#else if (IndirectAccessTrackingEnabled) { std::scoped_lock Lock(Platform->ContextsMutex); Platform->Contexts.push_back(*RetContext); } +#endif } catch (const std::bad_alloc &) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } catch (...) { @@ -355,13 +364,21 @@ ur_result_t ContextReleaseHelper(ur_context_handle_t Context) { if (!Context->RefCount.decrementAndTest()) return UR_RESULT_SUCCESS; - if (IndirectAccessTrackingEnabled) { + auto DeleteFromContextsCache = [&]() { ur_platform_handle_t Plt = Context->getPlatform(); auto &Contexts = Plt->Contexts; auto It = std::find(Contexts.begin(), Contexts.end(), Context); if (It != Contexts.end()) Contexts.erase(It); + }; + +#ifdef _WIN32 + DeleteFromContextsCache(); +#else + if (IndirectAccessTrackingEnabled) { + DeleteFromContextsCache(); } +#endif ze_context_handle_t DestroyZeContext = Context->OwnNativeHandle ? Context->ZeContext : nullptr; @@ -451,6 +468,12 @@ ur_result_t ur_context_handle_t_::finalize() { } } } + + for (auto &kernel : KernelsCache) { + UR_CALL(urKernelRelease(kernel)); + } + KernelsCache.clear(); + return UR_RESULT_SUCCESS; } diff --git a/source/adapters/level_zero/context.hpp b/source/adapters/level_zero/context.hpp index a1212f0698..f21f62b193 100644 --- a/source/adapters/level_zero/context.hpp +++ b/source/adapters/level_zero/context.hpp @@ -175,6 +175,8 @@ struct ur_context_handle_t_ : _ur_object { std::vector> EventCachesDeviceMap{4}; + std::vector KernelsCache; + // Initialize the PI context. ur_result_t initialize(); diff --git a/source/adapters/level_zero/kernel.cpp b/source/adapters/level_zero/kernel.cpp index 3469620b71..25b35bebc2 100644 --- a/source/adapters/level_zero/kernel.cpp +++ b/source/adapters/level_zero/kernel.cpp @@ -591,6 +591,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate( try { ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program); *RetKernel = reinterpret_cast(UrKernel); + +#ifdef _WIN32 + auto &Context = Program->Context; + auto It = std::find(Context->KernelsCache.begin(), + Context->KernelsCache.end(), *RetKernel); + if (It == Context->KernelsCache.end()) { + Context->KernelsCache.push_back(*RetKernel); + } +#endif + } catch (const std::bad_alloc &) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } catch (...) { @@ -902,6 +912,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease( return ze2urResult(ZeResult); } } + +#ifdef _WIN32 + auto &KernelsCache = Kernel->Program->Context->KernelsCache; + auto It = std::find(KernelsCache.begin(), KernelsCache.end(), Kernel); + if (It != KernelsCache.end()) + KernelsCache.erase(It); +#endif + Kernel->ZeKernelMap.clear(); if (IndirectAccessTrackingEnabled) { UR_CALL(urContextRelease(KernelProgram->Context));