diff --git a/source/adapters/level_zero/command_buffer.cpp b/source/adapters/level_zero/command_buffer.cpp index c4d9614159..701ddab569 100644 --- a/source/adapters/level_zero/command_buffer.cpp +++ b/source/adapters/level_zero/command_buffer.cpp @@ -949,41 +949,53 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer, auto Platform = CommandBuffer->Context->getPlatform(); auto ZeDevice = CommandBuffer->Device->ZeDevice; + ze_command_list_handle_t ZeCommandList = + CommandBuffer->ZeComputeCommandListTranslated; + if (Platform->ZeMutableCmdListExt.LoaderExtension) { + ZeCommandList = CommandBuffer->ZeComputeCommandList; + } if (NumKernelAlternatives > 0) { ZeMutableCommandDesc.flags |= ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION; - std::vector TranslatedKernelHandles( - NumKernelAlternatives + 1, nullptr); + std::vector KernelHandles(NumKernelAlternatives + 1, + nullptr); ze_kernel_handle_t ZeMainKernel{}; UR_CALL(getZeKernel(ZeDevice, Kernel, &ZeMainKernel)); - // Translate main kernel first - ZE2UR_CALL(zelLoaderTranslateHandle, - (ZEL_HANDLE_KERNEL, ZeMainKernel, - (void **)&TranslatedKernelHandles[0])); + if (Platform->ZeMutableCmdListExt.LoaderExtension) { + KernelHandles[0] = ZeMainKernel; + } else { + // If the L0 loader is not aware of the MCL extension, the main kernel + // handle needs to be translated. + ZE2UR_CALL(zelLoaderTranslateHandle, + (ZEL_HANDLE_KERNEL, ZeMainKernel, (void **)&KernelHandles[0])); + } for (size_t i = 0; i < NumKernelAlternatives; i++) { ze_kernel_handle_t ZeAltKernel{}; UR_CALL(getZeKernel(ZeDevice, KernelAlternatives[i], &ZeAltKernel)); - ZE2UR_CALL(zelLoaderTranslateHandle, - (ZEL_HANDLE_KERNEL, ZeAltKernel, - (void **)&TranslatedKernelHandles[i + 1])); + if (Platform->ZeMutableCmdListExt.LoaderExtension) { + KernelHandles[i + 1] = ZeAltKernel; + } else { + // If the L0 loader is not aware of the MCL extension, the kernel + // alternatives need to be translated. + ZE2UR_CALL(zelLoaderTranslateHandle, (ZEL_HANDLE_KERNEL, ZeAltKernel, + (void **)&KernelHandles[i + 1])); + } } ZE2UR_CALL(Platform->ZeMutableCmdListExt .zexCommandListGetNextCommandIdWithKernelsExp, - (CommandBuffer->ZeComputeCommandListTranslated, - &ZeMutableCommandDesc, NumKernelAlternatives + 1, - TranslatedKernelHandles.data(), &CommandId)); + (ZeCommandList, &ZeMutableCommandDesc, NumKernelAlternatives + 1, + KernelHandles.data(), &CommandId)); } else { ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp, - (CommandBuffer->ZeComputeCommandListTranslated, - &ZeMutableCommandDesc, &CommandId)); + (ZeCommandList, &ZeMutableCommandDesc, &CommandId)); } DEBUG_LOG(CommandId); @@ -1863,17 +1875,22 @@ ur_result_t updateKernelCommand( ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel; if (NewKernel && Command->Kernel != NewKernel) { + ze_kernel_handle_t KernelHandle{}; ze_kernel_handle_t ZeNewKernel{}; UR_CALL(getZeKernel(ZeDevice, NewKernel, &ZeNewKernel)); - ze_kernel_handle_t ZeKernelTranslated = nullptr; - ZE2UR_CALL(zelLoaderTranslateHandle, - (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&ZeKernelTranslated)); + ze_command_list_handle_t ZeCommandList = + CommandBuffer->ZeComputeCommandList; + KernelHandle = ZeNewKernel; + if (!Platform->ZeMutableCmdListExt.LoaderExtension) { + ZeCommandList = CommandBuffer->ZeComputeCommandListTranslated; + ZE2UR_CALL(zelLoaderTranslateHandle, + (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&KernelHandle)); + } ZE2UR_CALL(Platform->ZeMutableCmdListExt .zexCommandListUpdateMutableCommandKernelsExp, - (CommandBuffer->ZeComputeCommandListTranslated, 1, - &Command->CommandId, &ZeKernelTranslated)); + (ZeCommandList, 1, &Command->CommandId, &KernelHandle)); // Set current kernel to be the new kernel Command->Kernel = NewKernel; } @@ -2079,9 +2096,15 @@ ur_result_t updateKernelCommand( MutableCommandDesc.pNext = NextDesc; MutableCommandDesc.flags = 0; + ze_command_list_handle_t ZeCommandList = + CommandBuffer->ZeComputeCommandListTranslated; + if (Platform->ZeMutableCmdListExt.LoaderExtension) { + ZeCommandList = CommandBuffer->ZeComputeCommandList; + } + ZE2UR_CALL( Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp, - (CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc)); + (ZeCommandList, &MutableCommandDesc)); return UR_RESULT_SUCCESS; } diff --git a/source/adapters/level_zero/platform.cpp b/source/adapters/level_zero/platform.cpp index 26b5a03ed6..0848facdc9 100644 --- a/source/adapters/level_zero/platform.cpp +++ b/source/adapters/level_zero/platform.cpp @@ -386,6 +386,7 @@ ur_result_t ur_platform_handle_t_::initialize() { ZeMutableCmdListExt.Supported |= ZeMutableCmdListExt.zexCommandListGetNextCommandIdWithKernelsExp != nullptr; + ZeMutableCmdListExt.LoaderExtension = true; } else { ZeMutableCmdListExt.Supported |= (ZE_CALL_NOCHECK( diff --git a/source/adapters/level_zero/platform.hpp b/source/adapters/level_zero/platform.hpp index 0faa122651..1381f51bca 100644 --- a/source/adapters/level_zero/platform.hpp +++ b/source/adapters/level_zero/platform.hpp @@ -96,6 +96,12 @@ struct ur_platform_handle_t_ : public _ur_platform { // associated with particular Level Zero driver, store this extension here. struct ZeMutableCmdListExtension { bool Supported = false; + // If LoaderExtension is true, the L0 loader is aware of the MCL extension. + // If it is false, the extension has to be loaded directly from the driver + // using zeDriverGetExtensionFunctionAddress. If it is loaded directly from + // the driver, any handles passed to it must be translated using + // zelLoaderTranslateHandle. + bool LoaderExtension = false; ze_result_t (*zexCommandListGetNextCommandIdExp)( ze_command_list_handle_t, const ze_mutable_command_id_exp_desc_t *, uint64_t *) = nullptr;