Skip to content

Commit

Permalink
Merge pull request #35 from shady-gang/cuda-improvements
Browse files Browse the repository at this point in the history
improved CUDA module loading
  • Loading branch information
Hugobros3 authored Apr 16, 2024
2 parents dfea371 + 1962f52 commit 20827a2
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/runtime/cuda/cuda_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ static CudaDevice* create_cuda_device(CudaBackend* b, int ordinal) {
.specialized_programs = new_dict(SpecProgramKey, CudaKernel*, (HashFn) hash_spec_program_key, (CmpFn) cmp_spec_program_keys),
};
CHECK_CUDA(cuDeviceGetName(device->name, 255, handle), goto dealloc_and_return_null);
CHECK_CUDA(cuDeviceGetAttribute(&device->cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device->handle), goto dealloc_and_return_null);
CHECK_CUDA(cuDeviceGetAttribute(&device->cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device->handle), goto dealloc_and_return_null);
CHECK_CUDA(cuCtxCreate(&device->context, 0, handle), goto dealloc_and_return_null);
return device;

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/cuda/cuda_runtime_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ typedef struct {
CUdevice handle;
CUcontext context;
char name[256];
int cc_major;
int cc_minor;
struct Dict* specialized_programs;
} CudaDevice;

Expand Down
71 changes: 64 additions & 7 deletions src/runtime/cuda/cuda_runtime_program.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,34 @@ static bool emit_cuda_c_code(CudaKernel* spec) {
Module* final_mod;
emit_c(config, emitter_config, dst_mod, &spec->cuda_code_size, &spec->cuda_code, &final_mod);
spec->final_module = final_mod;

if (get_log_level() <= DEBUG)
write_file("cuda_dump.cu", spec->cuda_code_size - 1, spec->cuda_code);

return true;
}

static bool cuda_c_to_ptx(CudaKernel* kernel) {
nvrtcProgram program;
CHECK_NVRTC(nvrtcCreateProgram(&program, kernel->cuda_code, kernel->key.entry_point, 0, NULL, NULL), return false);
const char* args[] = { "--use_fast_math" };
nvrtcResult compile_result = nvrtcCompileProgram(program, sizeof(args) / sizeof(*args), args);

assert(kernel->device->cc_major < 10 && kernel->device->cc_minor < 10);

char arch_flag[] = "-arch=compute_00";
arch_flag[14] = '0' + kernel->device->cc_major;
arch_flag[15] = '0' + kernel->device->cc_minor;

const char* options[] = {
arch_flag,
"--use_fast_math"
};

nvrtcResult compile_result = nvrtcCompileProgram(program, sizeof(options)/sizeof(*options), options);
if (compile_result != NVRTC_SUCCESS) {
error_print("NVRTC compilation failed: %s\n", nvrtcGetErrorString(compile_result));
debug_print("Dumping source:\n%s", kernel->cuda_code);
}

if (get_log_level() <= DEBUG)
write_file("cuda_dump.cu", kernel->cuda_code_size - 1, kernel->cuda_code);

size_t log_size;
CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size), return false);
char* log_buffer = calloc(log_size, 1);
Expand All @@ -61,13 +73,58 @@ static bool cuda_c_to_ptx(CudaKernel* kernel) {
read_file(override_file, &kernel->ptx_size, &kernel->ptx);
}

if (get_log_level() <= DEBUG)
write_file("cuda_dump.ptx", kernel->ptx_size - 1, kernel->ptx);

return true;
}

/*
 * Turns the PTX stored in `kernel->ptx` into a loaded CUDA module and resolves
 * the kernel's entry point.
 *
 * The PTX is run through the driver's JIT linker (cuLinkCreate / cuLinkAddData /
 * cuLinkComplete) rather than handed straight to cuModuleLoadDataEx, so that the
 * JIT info/error logs can be captured and the target can be requested
 * explicitly from the device's compute capability.
 *
 * On success, `kernel->cuda_module` and `kernel->entry_point_function` are set
 * and true is returned. On failure, whatever was created along the way is
 * released, any JIT logs are printed, and false is returned.
 */
static bool load_ptx_into_cuda_program(CudaKernel* kernel) {
    /* Zero-filled so an untouched log reads as the empty string. */
    char info_log[10240] = {0};
    char error_log[10240] = {0};
    /* Set once the info log has been printed, so the failure path does not repeat it. */
    bool info_reported = false;

    /* options[i] is paired with option_values[i]; keep both arrays in the same order. */
    CUjit_option options[] = {
        CU_JIT_INFO_LOG_BUFFER, CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
        CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
        CU_JIT_TARGET
    };

    void* option_values[] = {
        info_log, (void*)(uintptr_t)sizeof(info_log),
        error_log, (void*)(uintptr_t)sizeof(error_log),
        /* Target passed as major*10+minor (e.g. 7.5 -> 75). */
        (void*)(uintptr_t)(kernel->device->cc_major * 10 + kernel->device->cc_minor)
    };

    CUlinkState linker;
    CHECK_CUDA(cuLinkCreate(sizeof(options)/sizeof(options[0]), options, option_values, &linker), goto err_linker_create);
    CHECK_CUDA(cuLinkAddData(linker, CU_JIT_INPUT_PTX, kernel->ptx, kernel->ptx_size, NULL, 0U, NULL, NULL), goto err_post_linker_create);

    void* binary;
    size_t binary_size;
    CHECK_CUDA(cuLinkComplete(linker, &binary, &binary_size), goto err_post_linker_create);

    if (*info_log) {
        info_print("CUDA JIT info: %s\n", info_log);
        info_reported = true;
    }

    if (get_log_level() <= DEBUG)
        write_file("cuda_dump.cubin", binary_size, binary);

    /*
     * `binary` belongs to the link state, so the module has to be loaded from it
     * before cuLinkDestroy is called. This is the only place the module is
     * loaded: loading it a second time would overwrite kernel->cuda_module and
     * leak the first handle.
     */
    CHECK_CUDA(cuModuleLoadData(&kernel->cuda_module, binary), goto err_post_linker_create);
    CHECK_CUDA(cuModuleGetFunction(&kernel->entry_point_function, kernel->cuda_module, kernel->key.entry_point), goto err_post_module_load);

    cuLinkDestroy(linker);
    return true;

    /* Cleanup labels fall through, releasing resources in reverse order of creation. */
    err_post_module_load:
    cuModuleUnload(kernel->cuda_module);
    err_post_linker_create:
    cuLinkDestroy(linker);
    if (*info_log && !info_reported)
        info_print("CUDA JIT info: %s\n", info_log);
    if (*error_log)
        error_print("CUDA JIT failed: %s\n", error_log);
    err_linker_create:
    return false;
}

static CudaKernel* create_specialized_program(CudaDevice* device, SpecProgramKey key) {
Expand Down

0 comments on commit 20827a2

Please sign in to comment.