diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a295ea396cd0..fbb4e13e0534 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -221,16 +221,14 @@ class OpenCLWorkspace : public DeviceAPI { public: // type key std::string type_key; - // global platform id - cl_platform_id platform_id; - // global platform name - std::string platform_name; - // global context of this process - cl_context context{nullptr}; + // available platforms + std::vector platform_ids; + // map platform to its context + std::unordered_map contexts; // whether the workspace it initialized. bool initialized_{false}; - // the device type - std::string device_type; + // map device to platform + std::unordered_map device_to_platform; // the devices std::vector devices; // the queues @@ -248,11 +246,11 @@ class OpenCLWorkspace : public DeviceAPI { std::mutex mu; // destructor ~OpenCLWorkspace() { - if (context != nullptr) { - OPENCL_CALL(clReleaseContext(context)); + for (auto& it : contexts) { + OPENCL_CALL(clReleaseContext(it.second)); } } - // Initialzie the device. + // Initialize the device. void Init(const std::string& type_key, const std::string& device_type, const std::string& platform_name = ""); virtual void Init() { Init("opencl", "gpu"); } @@ -296,13 +294,15 @@ class OpenCLWorkspace : public DeviceAPI { OPENCL_CALL(clFinish(queue)); OPENCL_CALL(clReleaseCommandQueue(queue)); cl_int err_code; - cl_device_id did = cl::OpenCLWorkspace::Global()->devices[dev.device_id]; - auto profiling_queue = - clCreateCommandQueue(cl::OpenCLWorkspace::Global()->context, did, prop, &err_code); + cl_device_id did = cl::OpenCLWorkspace::Global()->GetCLDeviceID(dev.device_id); + cl_platform_id platform = cl::OpenCLWorkspace::Global()->device_to_platform[did]; + auto profiling_queue = clCreateCommandQueue(cl::OpenCLWorkspace::Global()->contexts[platform], + did, prop, &err_code); OPENCL_CHECK_ERROR(err_code); cl::OpenCLWorkspace::Global()->queues[dev.device_id] = profiling_queue; } + cl_device_id GetCLDeviceID(int device_id); // override device API void SetDevice(Device dev) final; void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index c53523267d66..f3eb8d83a210 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -110,6 +110,11 @@ OpenCLWorkspace* OpenCLWorkspace::Global() { return inst; } +cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) { + ICHECK_LT(device_id, devices.size()) << "Invalid device id " << device_id << ". " << GetError(); + return devices[device_id]; +} + void OpenCLWorkspace::SetDevice(Device dev) { GetThreadEntry()->device.device_id = dev.device_id; } void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { @@ -119,14 +124,14 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) *rv = static_cast(index < devices.size()); return; } - ICHECK_LT(index, devices.size()) << "Invalid device id " << index << ". " << GetError(); + cl_device_id device_id = GetCLDeviceID(index); switch (kind) { case kExist: break; case kMaxThreadsPerBlock: { size_t value; - OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), - &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &value, + nullptr)); *rv = static_cast(value); break; } @@ -142,21 +147,21 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) } case kMaxSharedMemoryPerBlock: { cl_ulong value; - OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_LOCAL_MEM_SIZE, sizeof(cl_ulong), - &value, nullptr)); + OPENCL_CALL( + clGetDeviceInfo(device_id, CL_DEVICE_LOCAL_MEM_SIZE, sizeof(cl_ulong), &value, nullptr)); *rv = static_cast(value); break; } case kComputeVersion: - *rv = GetOpenCLVersion(devices[index]); + *rv = GetOpenCLVersion(device_id); break; case kDeviceName: - *rv = GetDeviceInfo(devices[index], CL_DEVICE_NAME); + *rv = GetDeviceInfo(device_id, CL_DEVICE_NAME); break; case kMaxClockRate: { cl_uint value; - OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(cl_uint), - &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(cl_uint), &value, + nullptr)); // OpenCL returns the clock rate in MHz, while CUDA/ROCm return the // clock rate in kHz. Converting to the same units for each. *rv = static_cast(value * 1000); @@ -164,15 +169,15 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) } case kMultiProcessorCount: { cl_uint value; - OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(cl_uint), - &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(cl_uint), &value, + nullptr)); *rv = static_cast(value); break; } case kMaxThreadDimensions: { size_t dims[3]; - OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, - nullptr)); + OPENCL_CALL( + clGetDeviceInfo(device_id, CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr)); std::stringstream ss; // use json string to return multiple int values; ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; @@ -189,8 +194,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) } case kDriverVersion: { char value[128] = {0}; - OPENCL_CALL( - clGetDeviceInfo(devices[index], CL_DRIVER_VERSION, sizeof(value) - 1, value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(device_id, CL_DRIVER_VERSION, sizeof(value) - 1, value, nullptr)); *rv = std::string(value); break; } @@ -211,14 +215,16 @@ void* OpenCLWorkspace::CreateHostPtrIfEnabled(cl::BufferDescriptor* desc, Device void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t size, size_t alignment, DLDataType type_hint) { this->Init(); - ICHECK(context != nullptr) << "No OpenCL device. " << GetError(); + cl_device_id device_id = GetCLDeviceID(dev.device_id); + auto platform = device_to_platform[device_id]; cl_int err_code; cl::BufferDescriptor* desc = new cl::BufferDescriptor; // CL_INVALID_BUFFER_SIZE if size is 0. if (size == 0) { size = 1; } - desc->buffer = clCreateBuffer(this->context, CL_MEM_CREATE_FLAGS, size, nullptr, &err_code); + desc->buffer = + clCreateBuffer(this->contexts[platform], CL_MEM_CREATE_FLAGS, size, nullptr, &err_code); desc->layout = cl::BufferDescriptor::MemoryLayout::kBuffer1D; OPENCL_CHECK_ERROR(err_code); return CreateHostPtrIfEnabled(desc, dev, size); @@ -265,13 +271,14 @@ void OpenCLWorkspace::FreeDataSpace(Device dev, void* ptr) { cl_mem OpenCLWorkspace::AllocTexture(Device dev, size_t width, size_t height, DLDataType type_hint) { this->Init(); - ICHECK(context != nullptr) << "No OpenCL device. " << GetError(); + cl_device_id device_id = GetCLDeviceID(dev.device_id); + auto platform = device_to_platform[device_id]; cl_int err_code; cl_channel_type cl_type = DTypeToOpenCLChannelType(type_hint); cl_image_format format = {CL_RGBA, cl_type}; cl_image_desc descriptor = {CL_MEM_OBJECT_IMAGE2D, width, height, 0, 0, 0, 0, 0, 0}; - cl_mem mptr = - clCreateImage(this->context, CL_MEM_CREATE_FLAGS, &format, &descriptor, nullptr, &err_code); + cl_mem mptr = clCreateImage(this->contexts[platform], CL_MEM_CREATE_FLAGS, &format, &descriptor, + nullptr, &err_code); OPENCL_CHECK_ERROR(err_code); return mptr; } @@ -445,7 +452,6 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic if (initialized_) return; std::lock_guard lock(this->mu); if (initialized_) return; - if (context != nullptr) return; this->type_key = type_key; // matched platforms std::vector platform_ids = cl::GetPlatformIDs(); @@ -453,64 +459,69 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic LOG(WARNING) << "No OpenCL platform matched given existing options ..."; return; } - this->platform_id = nullptr; - for (auto platform_id : platform_ids) { - if (!MatchPlatformInfo(platform_id, CL_PLATFORM_NAME, platform_name)) { - continue; - } - std::vector devices_matched = cl::GetDeviceIDs(platform_id, device_type); - if ((devices_matched.size() == 0) && (device_type == "gpu")) { - LOG(WARNING) << "Using CPU OpenCL device"; - devices_matched = cl::GetDeviceIDs(platform_id, "cpu"); - } - std::vector supported_devices = {}; - auto get_version_str = [](int version) { - std::ostringstream out; - out.precision(1); - out << std::fixed << version / 100.f; - return out.str(); - }; - for (auto& device : devices_matched) { - std::string ver = GetOpenCLVersion(device); - int opencl_version = std::stod(ver) * 100; - if (opencl_version >= CL_TARGET_OPENCL_VERSION) { - supported_devices.push_back(device); - } else { - std::string dev_msg = GetDeviceInfo(device, CL_DEVICE_NAME) + - " has OpenCL version == " + get_version_str(opencl_version); - LOG(WARNING) << "TVM supports devices with OpenCL version >= " - << get_version_str(CL_TARGET_OPENCL_VERSION) << ", device " << dev_msg - << ". This device will be ignored."; - - if (noDevicesErrorMsg.empty()) { - noDevicesErrorMsg = - "Probably this error happen because TVM supports devices with OpenCL version >= " + - get_version_str(CL_TARGET_OPENCL_VERSION) + ". We found the following devices:\n"; + auto find_opencl_device = [&](const std::string& device_type, const std::string& platform_name) { + std::unordered_map> device_map; + for (auto platform_id : platform_ids) { + if (!MatchPlatformInfo(platform_id, CL_PLATFORM_NAME, platform_name)) { + continue; + } + std::vector devices_matched = cl::GetDeviceIDs(platform_id, device_type); + std::vector supported_devices = {}; + auto get_version_str = [](int version) { + std::ostringstream out; + out.precision(1); + out << std::fixed << version / 100.f; + return out.str(); + }; + for (auto& device : devices_matched) { + std::string ver = GetOpenCLVersion(device); + int opencl_version = std::stod(ver) * 100; + if (opencl_version >= CL_TARGET_OPENCL_VERSION) { + supported_devices.push_back(device); + } else { + std::string dev_msg = GetDeviceInfo(device, CL_DEVICE_NAME) + + " has OpenCL version == " + get_version_str(opencl_version); + LOG(WARNING) << "TVM supports devices with OpenCL version >= " + << get_version_str(CL_TARGET_OPENCL_VERSION) << ", device " << dev_msg + << ". This device will be ignored."; + + if (noDevicesErrorMsg.empty()) { + noDevicesErrorMsg = + "Probably this error happen because TVM supports devices with OpenCL version >= " + + get_version_str(CL_TARGET_OPENCL_VERSION) + ". We found the following devices:\n"; + } + noDevicesErrorMsg += "\t" + dev_msg + "\n"; } - noDevicesErrorMsg += "\t" + dev_msg + "\n"; + } + if (supported_devices.size()) { + device_map[platform_id] = supported_devices; } } - if (supported_devices.size() > 0) { - this->platform_id = platform_id; - this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME); - this->device_type = device_type; - this->devices = supported_devices; - break; - } + return device_map; + }; + auto device_map = find_opencl_device(device_type, platform_name); + if ((device_map.size() == 0) && (device_type == "gpu")) { + LOG(WARNING) << "Using CPU OpenCL device"; + device_map = find_opencl_device("cpu", ""); } - if (this->platform_id == nullptr) { + if (device_map.empty()) { LOG(WARNING) << "No OpenCL device"; initialized_ = true; return; } - cl_int err_code; - this->context = clCreateContext(nullptr, this->devices.size(), &(this->devices[0]), nullptr, - nullptr, &err_code); - OPENCL_CHECK_ERROR(err_code); ICHECK_EQ(this->queues.size(), 0U); - for (size_t i = 0; i < this->devices.size(); ++i) { - cl_device_id did = this->devices[i]; - this->queues.push_back(clCreateCommandQueue(this->context, did, 0, &err_code)); + cl_int err_code; + for (auto& [platform, devices] : device_map) { + this->platform_ids.push_back(platform); + this->contexts[platform] = + clCreateContext(nullptr, devices.size(), &(devices[0]), nullptr, nullptr, &err_code); + this->devices.insert(this->devices.end(), devices.begin(), devices.end()); + for (size_t i = 0; i < devices.size(); ++i) { + cl_device_id did = devices[i]; + device_to_platform[did] = platform; + this->queues.push_back(clCreateCommandQueue(this->contexts[platform], did, 0, &err_code)); + OPENCL_CHECK_ERROR(err_code); + } OPENCL_CHECK_ERROR(err_code); } this->events.resize(this->devices.size()); diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index ad41a34dde4e..7c084758a456 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -51,7 +51,7 @@ class OpenCLWrappedFunc { } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { - ICHECK(w_->context != nullptr) << "No OpenCL device"; + ICHECK(w_->devices.size() > 0) << "No OpenCL device"; cl::OpenCLThreadEntry* t = w_->GetThreadEntry(); // get the kernel from thread local kernel table. if (entry_.kernel_id >= t->kernel_table.size()) { @@ -227,13 +227,16 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre const std::string& func_name, const KTRefEntry& e) { std::lock_guard lock(build_lock_); int device_id = t->device.device_id; + auto did = w->GetCLDeviceID(device_id); + auto platform = w->device_to_platform[did]; if (programs_[func_name][device_id] == nullptr) { // create program if (fmt_ == "cl") { const char* s = parsed_kernels_[func_name].c_str(); size_t len = parsed_kernels_[func_name].length(); cl_int err; - programs_[func_name][device_id] = clCreateProgramWithSource(w->context, 1, &s, &len, &err); + programs_[func_name][device_id] = + clCreateProgramWithSource(w->contexts[platform], 1, &s, &len, &err); OPENCL_CHECK_ERROR(err); } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") { const unsigned char* s = (const unsigned char*)data_.c_str(); @@ -241,7 +244,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre cl_int err; cl_device_id dev = w->devices[device_id]; programs_[func_name][device_id] = - clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, nullptr, &err); + clCreateProgramWithBinary(w->contexts[platform], 1, &dev, &len, &s, nullptr, &err); OPENCL_CHECK_ERROR(err); } else { LOG(FATAL) << "Unknown OpenCL format " << fmt_; @@ -290,9 +293,11 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { size_t binarySize = bin_vector.size(); const unsigned char* programBinary = bin_vector.data(); - cl_device_id dev = workspace_->devices[device_id]; - programs_[name][device_id] = clCreateProgramWithBinary( - workspace_->context, 1, &dev, &binarySize, &programBinary, &binaryStatus, &err); + cl_device_id dev = workspace_->GetCLDeviceID(device_id); + auto platform = workspace_->device_to_platform[dev]; + programs_[name][device_id] = + clCreateProgramWithBinary(workspace_->contexts[platform], 1, &dev, &binarySize, + &programBinary, &binaryStatus, &err); OPENCL_CHECK_ERROR(err); OPENCL_CHECK_ERROR(binaryStatus); diff --git a/tests/cpp-runtime/opencl/opencl_timer_test.cc b/tests/cpp-runtime/opencl/opencl_timer_test.cc index f6546c25aca5..1753300d3a09 100644 --- a/tests/cpp-runtime/opencl/opencl_timer_test.cc +++ b/tests/cpp-runtime/opencl/opencl_timer_test.cc @@ -31,22 +31,23 @@ using namespace tvm::runtime::cl; TEST(OpenCLTimerNode, nested_timers) { OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); OpenCLThreadEntry* thr = workspace->GetThreadEntry(); - cl_command_queue queue = workspace->GetQueue(thr->device); int err; cl_int* tmp_buf = new cl_int[BUFF_SIZE]; int64_t nested_time_sum = 0; + auto did = workspace->GetCLDeviceID(thr->device.device_id); + auto platform = workspace->device_to_platform[did]; Timer init_timer = Timer::Start(thr->device); for (int i = 0; i < NUM_REPEAT; ++i) { Timer nested_timer = Timer::Start(thr->device); // create some events - cl_event ev = clCreateUserEvent(workspace->context, &err); + cl_event ev = clCreateUserEvent(workspace->contexts[platform], &err); OPENCL_CHECK_ERROR(err); - cl_mem cl_buf = clCreateBuffer(workspace->context, CL_MEM_READ_ONLY, BUFF_SIZE * sizeof(cl_int), - nullptr, &err); + cl_mem cl_buf = clCreateBuffer(workspace->contexts[platform], CL_MEM_READ_ONLY, + BUFF_SIZE * sizeof(cl_int), nullptr, &err); OPENCL_CHECK_ERROR(err); - queue = workspace->GetQueue(thr->device); + auto queue = workspace->GetQueue(thr->device); OPENCL_CALL(clEnqueueWriteBuffer(queue, cl_buf, false, 0, BUFF_SIZE * sizeof(cl_int), tmp_buf, 0, nullptr, &ev)); OPENCL_CALL(clReleaseMemObject(cl_buf));