Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenCL] Refactor OpenCL init function #13919

Merged
merged 1 commit into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,14 @@ class OpenCLWorkspace : public DeviceAPI {
public:
// type key
std::string type_key;
// global platform id
cl_platform_id platform_id;
// global platform name
std::string platform_name;
// global context of this process
cl_context context{nullptr};
// available platforms
std::vector<cl_platform_id> platform_ids;
// map platform to its context
std::unordered_map<cl_platform_id, cl_context> contexts;
// whether the workspace it initialized.
bool initialized_{false};
// the device type
std::string device_type;
// map device to platform
std::unordered_map<cl_device_id, cl_platform_id> device_to_platform;
// the devices
std::vector<cl_device_id> devices;
// the queues
Expand All @@ -248,11 +246,11 @@ class OpenCLWorkspace : public DeviceAPI {
std::mutex mu;
// destructor
~OpenCLWorkspace() {
if (context != nullptr) {
OPENCL_CALL(clReleaseContext(context));
for (auto& it : contexts) {
OPENCL_CALL(clReleaseContext(it.second));
}
}
// Initialzie the device.
// Initialize the device.
void Init(const std::string& type_key, const std::string& device_type,
const std::string& platform_name = "");
virtual void Init() { Init("opencl", "gpu"); }
Expand Down Expand Up @@ -296,13 +294,15 @@ class OpenCLWorkspace : public DeviceAPI {
OPENCL_CALL(clFinish(queue));
OPENCL_CALL(clReleaseCommandQueue(queue));
cl_int err_code;
cl_device_id did = cl::OpenCLWorkspace::Global()->devices[dev.device_id];
auto profiling_queue =
clCreateCommandQueue(cl::OpenCLWorkspace::Global()->context, did, prop, &err_code);
cl_device_id did = cl::OpenCLWorkspace::Global()->GetCLDeviceID(dev.device_id);
cl_platform_id platform = cl::OpenCLWorkspace::Global()->device_to_platform[did];
auto profiling_queue = clCreateCommandQueue(cl::OpenCLWorkspace::Global()->contexts[platform],
did, prop, &err_code);
OPENCL_CHECK_ERROR(err_code);
cl::OpenCLWorkspace::Global()->queues[dev.device_id] = profiling_queue;
}

cl_device_id GetCLDeviceID(int device_id);
// override device API
void SetDevice(Device dev) final;
void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final;
Expand Down
151 changes: 81 additions & 70 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
return inst;
}

cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) {
ICHECK_LT(device_id, devices.size()) << "Invalid device id " << device_id << ". " << GetError();
return devices[device_id];
}

void OpenCLWorkspace::SetDevice(Device dev) { GetThreadEntry()->device.device_id = dev.device_id; }

void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
Expand All @@ -119,14 +124,14 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
*rv = static_cast<int>(index < devices.size());
return;
}
ICHECK_LT(index, devices.size()) << "Invalid device id " << index << ". " << GetError();
cl_device_id device_id = GetCLDeviceID(index);
switch (kind) {
case kExist:
break;
case kMaxThreadsPerBlock: {
size_t value;
OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t),
&value, nullptr));
OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &value,
nullptr));
*rv = static_cast<int64_t>(value);
break;
}
Expand All @@ -142,37 +147,37 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
}
case kMaxSharedMemoryPerBlock: {
cl_ulong value;
OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_LOCAL_MEM_SIZE, sizeof(cl_ulong),
&value, nullptr));
OPENCL_CALL(
clGetDeviceInfo(device_id, CL_DEVICE_LOCAL_MEM_SIZE, sizeof(cl_ulong), &value, nullptr));
*rv = static_cast<int64_t>(value);
break;
}
case kComputeVersion:
*rv = GetOpenCLVersion(devices[index]);
*rv = GetOpenCLVersion(device_id);
break;
case kDeviceName:
*rv = GetDeviceInfo(devices[index], CL_DEVICE_NAME);
*rv = GetDeviceInfo(device_id, CL_DEVICE_NAME);
break;
case kMaxClockRate: {
cl_uint value;
OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(cl_uint),
&value, nullptr));
OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(cl_uint), &value,
nullptr));
// OpenCL returns the clock rate in MHz, while CUDA/ROCm return the
// clock rate in kHz. Converting to the same units for each.
*rv = static_cast<int32_t>(value * 1000);
break;
}
case kMultiProcessorCount: {
cl_uint value;
OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(cl_uint),
&value, nullptr));
OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(cl_uint), &value,
nullptr));
*rv = static_cast<int32_t>(value);
break;
}
case kMaxThreadDimensions: {
size_t dims[3];
OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims,
nullptr));
OPENCL_CALL(
clGetDeviceInfo(device_id, CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr));

std::stringstream ss; // use json string to return multiple int values;
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
Expand All @@ -189,8 +194,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
}
case kDriverVersion: {
char value[128] = {0};
OPENCL_CALL(
clGetDeviceInfo(devices[index], CL_DRIVER_VERSION, sizeof(value) - 1, value, nullptr));
OPENCL_CALL(clGetDeviceInfo(device_id, CL_DRIVER_VERSION, sizeof(value) - 1, value, nullptr));
*rv = std::string(value);
break;
}
Expand All @@ -211,14 +215,16 @@ void* OpenCLWorkspace::CreateHostPtrIfEnabled(cl::BufferDescriptor* desc, Device
void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t size, size_t alignment,
DLDataType type_hint) {
this->Init();
ICHECK(context != nullptr) << "No OpenCL device. " << GetError();
cl_device_id device_id = GetCLDeviceID(dev.device_id);
auto platform = device_to_platform[device_id];
cl_int err_code;
cl::BufferDescriptor* desc = new cl::BufferDescriptor;
// CL_INVALID_BUFFER_SIZE if size is 0.
if (size == 0) {
size = 1;
}
desc->buffer = clCreateBuffer(this->context, CL_MEM_CREATE_FLAGS, size, nullptr, &err_code);
desc->buffer =
clCreateBuffer(this->contexts[platform], CL_MEM_CREATE_FLAGS, size, nullptr, &err_code);
desc->layout = cl::BufferDescriptor::MemoryLayout::kBuffer1D;
OPENCL_CHECK_ERROR(err_code);
return CreateHostPtrIfEnabled(desc, dev, size);
Expand Down Expand Up @@ -265,13 +271,14 @@ void OpenCLWorkspace::FreeDataSpace(Device dev, void* ptr) {
cl_mem OpenCLWorkspace::AllocTexture(Device dev, size_t width, size_t height,
DLDataType type_hint) {
this->Init();
ICHECK(context != nullptr) << "No OpenCL device. " << GetError();
cl_device_id device_id = GetCLDeviceID(dev.device_id);
auto platform = device_to_platform[device_id];
cl_int err_code;
cl_channel_type cl_type = DTypeToOpenCLChannelType(type_hint);
cl_image_format format = {CL_RGBA, cl_type};
cl_image_desc descriptor = {CL_MEM_OBJECT_IMAGE2D, width, height, 0, 0, 0, 0, 0, 0};
cl_mem mptr =
clCreateImage(this->context, CL_MEM_CREATE_FLAGS, &format, &descriptor, nullptr, &err_code);
cl_mem mptr = clCreateImage(this->contexts[platform], CL_MEM_CREATE_FLAGS, &format, &descriptor,
nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
return mptr;
}
Expand Down Expand Up @@ -445,72 +452,76 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mu);
if (initialized_) return;
if (context != nullptr) return;
this->type_key = type_key;
// matched platforms
std::vector<cl_platform_id> platform_ids = cl::GetPlatformIDs();
if (platform_ids.size() == 0) {
LOG(WARNING) << "No OpenCL platform matched given existing options ...";
return;
}
this->platform_id = nullptr;
for (auto platform_id : platform_ids) {
if (!MatchPlatformInfo(platform_id, CL_PLATFORM_NAME, platform_name)) {
continue;
}
std::vector<cl_device_id> devices_matched = cl::GetDeviceIDs(platform_id, device_type);
if ((devices_matched.size() == 0) && (device_type == "gpu")) {
LOG(WARNING) << "Using CPU OpenCL device";
devices_matched = cl::GetDeviceIDs(platform_id, "cpu");
}
std::vector<cl_device_id> supported_devices = {};
auto get_version_str = [](int version) {
std::ostringstream out;
out.precision(1);
out << std::fixed << version / 100.f;
return out.str();
};
for (auto& device : devices_matched) {
std::string ver = GetOpenCLVersion(device);
int opencl_version = std::stod(ver) * 100;
if (opencl_version >= CL_TARGET_OPENCL_VERSION) {
supported_devices.push_back(device);
} else {
std::string dev_msg = GetDeviceInfo(device, CL_DEVICE_NAME) +
" has OpenCL version == " + get_version_str(opencl_version);
LOG(WARNING) << "TVM supports devices with OpenCL version >= "
<< get_version_str(CL_TARGET_OPENCL_VERSION) << ", device " << dev_msg
<< ". This device will be ignored.";

if (noDevicesErrorMsg.empty()) {
noDevicesErrorMsg =
"Probably this error happen because TVM supports devices with OpenCL version >= " +
get_version_str(CL_TARGET_OPENCL_VERSION) + ". We found the following devices:\n";
auto find_opencl_device = [&](const std::string& device_type, const std::string& platform_name) {
std::unordered_map<cl_platform_id, std::vector<cl_device_id>> device_map;
for (auto platform_id : platform_ids) {
if (!MatchPlatformInfo(platform_id, CL_PLATFORM_NAME, platform_name)) {
continue;
}
std::vector<cl_device_id> devices_matched = cl::GetDeviceIDs(platform_id, device_type);
std::vector<cl_device_id> supported_devices = {};
auto get_version_str = [](int version) {
std::ostringstream out;
out.precision(1);
out << std::fixed << version / 100.f;
return out.str();
};
for (auto& device : devices_matched) {
std::string ver = GetOpenCLVersion(device);
int opencl_version = std::stod(ver) * 100;
if (opencl_version >= CL_TARGET_OPENCL_VERSION) {
supported_devices.push_back(device);
} else {
std::string dev_msg = GetDeviceInfo(device, CL_DEVICE_NAME) +
" has OpenCL version == " + get_version_str(opencl_version);
LOG(WARNING) << "TVM supports devices with OpenCL version >= "
<< get_version_str(CL_TARGET_OPENCL_VERSION) << ", device " << dev_msg
<< ". This device will be ignored.";

if (noDevicesErrorMsg.empty()) {
noDevicesErrorMsg =
"Probably this error happen because TVM supports devices with OpenCL version >= " +
get_version_str(CL_TARGET_OPENCL_VERSION) + ". We found the following devices:\n";
}
noDevicesErrorMsg += "\t" + dev_msg + "\n";
}
noDevicesErrorMsg += "\t" + dev_msg + "\n";
}
if (supported_devices.size()) {
device_map[platform_id] = supported_devices;
}
}
if (supported_devices.size() > 0) {
this->platform_id = platform_id;
this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME);
this->device_type = device_type;
this->devices = supported_devices;
break;
}
return device_map;
};
auto device_map = find_opencl_device(device_type, platform_name);
if ((device_map.size() == 0) && (device_type == "gpu")) {
LOG(WARNING) << "Using CPU OpenCL device";
device_map = find_opencl_device("cpu", "");
}
if (this->platform_id == nullptr) {
if (device_map.empty()) {
LOG(WARNING) << "No OpenCL device";
initialized_ = true;
return;
}
cl_int err_code;
this->context = clCreateContext(nullptr, this->devices.size(), &(this->devices[0]), nullptr,
nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
ICHECK_EQ(this->queues.size(), 0U);
for (size_t i = 0; i < this->devices.size(); ++i) {
cl_device_id did = this->devices[i];
this->queues.push_back(clCreateCommandQueue(this->context, did, 0, &err_code));
cl_int err_code;
for (auto& [platform, devices] : device_map) {
this->platform_ids.push_back(platform);
this->contexts[platform] =
clCreateContext(nullptr, devices.size(), &(devices[0]), nullptr, nullptr, &err_code);
this->devices.insert(this->devices.end(), devices.begin(), devices.end());
for (size_t i = 0; i < devices.size(); ++i) {
cl_device_id did = devices[i];
device_to_platform[did] = platform;
this->queues.push_back(clCreateCommandQueue(this->contexts[platform], did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code);
}
OPENCL_CHECK_ERROR(err_code);
}
this->events.resize(this->devices.size());
Expand Down
17 changes: 11 additions & 6 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class OpenCLWrappedFunc {
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
ICHECK(w_->context != nullptr) << "No OpenCL device";
ICHECK(w_->devices.size() > 0) << "No OpenCL device";
cl::OpenCLThreadEntry* t = w_->GetThreadEntry();
// get the kernel from thread local kernel table.
if (entry_.kernel_id >= t->kernel_table.size()) {
Expand Down Expand Up @@ -227,21 +227,24 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
const std::string& func_name, const KTRefEntry& e) {
std::lock_guard<std::mutex> lock(build_lock_);
int device_id = t->device.device_id;
auto did = w->GetCLDeviceID(device_id);
auto platform = w->device_to_platform[did];
if (programs_[func_name][device_id] == nullptr) {
// create program
if (fmt_ == "cl") {
const char* s = parsed_kernels_[func_name].c_str();
size_t len = parsed_kernels_[func_name].length();
cl_int err;
programs_[func_name][device_id] = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
programs_[func_name][device_id] =
clCreateProgramWithSource(w->contexts[platform], 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
} else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") {
const unsigned char* s = (const unsigned char*)data_.c_str();
size_t len = data_.length();
cl_int err;
cl_device_id dev = w->devices[device_id];
programs_[func_name][device_id] =
clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, nullptr, &err);
clCreateProgramWithBinary(w->contexts[platform], 1, &dev, &len, &s, nullptr, &err);
OPENCL_CHECK_ERROR(err);
} else {
LOG(FATAL) << "Unknown OpenCL format " << fmt_;
Expand Down Expand Up @@ -290,9 +293,11 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) {
size_t binarySize = bin_vector.size();
const unsigned char* programBinary = bin_vector.data();

cl_device_id dev = workspace_->devices[device_id];
programs_[name][device_id] = clCreateProgramWithBinary(
workspace_->context, 1, &dev, &binarySize, &programBinary, &binaryStatus, &err);
cl_device_id dev = workspace_->GetCLDeviceID(device_id);
auto platform = workspace_->device_to_platform[dev];
programs_[name][device_id] =
clCreateProgramWithBinary(workspace_->contexts[platform], 1, &dev, &binarySize,
&programBinary, &binaryStatus, &err);
OPENCL_CHECK_ERROR(err);
OPENCL_CHECK_ERROR(binaryStatus);

Expand Down
11 changes: 6 additions & 5 deletions tests/cpp-runtime/opencl/opencl_timer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,23 @@ using namespace tvm::runtime::cl;
TEST(OpenCLTimerNode, nested_timers) {
OpenCLWorkspace* workspace = OpenCLWorkspace::Global();
OpenCLThreadEntry* thr = workspace->GetThreadEntry();
cl_command_queue queue = workspace->GetQueue(thr->device);

int err;
cl_int* tmp_buf = new cl_int[BUFF_SIZE];
int64_t nested_time_sum = 0;

auto did = workspace->GetCLDeviceID(thr->device.device_id);
auto platform = workspace->device_to_platform[did];
Timer init_timer = Timer::Start(thr->device);
for (int i = 0; i < NUM_REPEAT; ++i) {
Timer nested_timer = Timer::Start(thr->device);
// create some events
cl_event ev = clCreateUserEvent(workspace->context, &err);
cl_event ev = clCreateUserEvent(workspace->contexts[platform], &err);
OPENCL_CHECK_ERROR(err);
cl_mem cl_buf = clCreateBuffer(workspace->context, CL_MEM_READ_ONLY, BUFF_SIZE * sizeof(cl_int),
nullptr, &err);
cl_mem cl_buf = clCreateBuffer(workspace->contexts[platform], CL_MEM_READ_ONLY,
BUFF_SIZE * sizeof(cl_int), nullptr, &err);
OPENCL_CHECK_ERROR(err);
queue = workspace->GetQueue(thr->device);
auto queue = workspace->GetQueue(thr->device);
OPENCL_CALL(clEnqueueWriteBuffer(queue, cl_buf, false, 0, BUFF_SIZE * sizeof(cl_int), tmp_buf,
0, nullptr, &ev));
OPENCL_CALL(clReleaseMemObject(cl_buf));
Expand Down