diff --git a/source/adapters/level_zero/program.cpp b/source/adapters/level_zero/program.cpp index 4c77d14f33..e7ac31f769 100644 --- a/source/adapters/level_zero/program.cpp +++ b/source/adapters/level_zero/program.cpp @@ -81,7 +81,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary( ur_program_handle_t *Program ///< [out] pointer to handle of Program object created. ) { - std::ignore = Device; std::ignore = Properties; // In OpenCL, clCreateProgramWithBinary() can be used to load any of the // following: "program executable", "compiled program", or "library of @@ -97,6 +96,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary( try { ur_program_handle_t_ *UrProgram = new ur_program_handle_t_( ur_program_handle_t_::Native, Context, Binary, Size); + UrProgram->BinaryDeviceHandle = Device; *Program = reinterpret_cast(UrProgram); } catch (const std::bad_alloc &) { return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; @@ -601,12 +601,37 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetInfo( return ReturnValue(uint32_t{Program->RefCount.load()}); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(Program->Context); - case UR_PROGRAM_INFO_NUM_DEVICES: - // TODO: return true number of devices this program exists for. - return ReturnValue(uint32_t{1}); - case UR_PROGRAM_INFO_DEVICES: - // TODO: return all devices this program exists for. - return ReturnValue(Program->Context->Devices[0]); + case UR_PROGRAM_INFO_NUM_DEVICES: { + if (Program->BinaryDeviceHandle != nullptr || + Program->ZeModuleMap.empty()) { + return ReturnValue(uint32_t{1}); + } + return ReturnValue(static_cast(Program->ZeModuleMap.size())); + } + case UR_PROGRAM_INFO_DEVICES: { + if (Program->BinaryDeviceHandle != nullptr) { + return ReturnValue(Program->BinaryDeviceHandle); + } + if (Program->ZeModuleMap.empty()) { + // TODO: urProgramCreateWithNativeHandle does + // not give us the devices. Return first available. + return ReturnValue(Program->Context->Devices[0]); + } + std::vector devices; + for (const auto &entry : Program->ZeModuleMap) { + const auto &devs = Program->Context->Devices; + auto it = std::find_if(devs.begin(), devs.end(), + [entry](const ur_device_handle_t &d) { + return d->ZeDevice == entry.first; + }); + if (it != devs.end()) { + devices.push_back(*it); + } else { + ur::unreachable(); + } + } + return ReturnValue(devices.data(), devices.size()); + } case UR_PROGRAM_INFO_BINARY_SIZES: { std::shared_lock Guard(Program->Mutex); size_t SzBinary; @@ -827,6 +852,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithNativeHandle( // represent a fully linked executable (state Exe) and not an unlinked // executable (state Object). + // TODO: this entry point does not specify which devices this program was + // compiled for, which means that the adapter has no way of knowing what to + // return in UR_PROGRAM_INFO_DEVICES. One hacky solution is to recompile the + // program for ALL devices in the context with the use of + // zeModuleGetNativeBinary. + try { ur_program_handle_t_ *UrProgram = new ur_program_handle_t_(ur_program_handle_t_::Exe, Context, ZeModule, diff --git a/source/adapters/level_zero/program.hpp b/source/adapters/level_zero/program.hpp index 8d148c8fa2..21f0c0d108 100644 --- a/source/adapters/level_zero/program.hpp +++ b/source/adapters/level_zero/program.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "ur_api.h" struct ur_program_handle_t_ : _ur_object { // ur_program_handle_t_() {} @@ -147,4 +148,7 @@ struct ur_program_handle_t_ : _ur_object { // Program has been built. std::unordered_map ZeBuildLogMap; + + // Device handle provided at creation time with urProgramCreateWithBinary. + ur_device_handle_t BinaryDeviceHandle = nullptr; };