Add runtime compilation
- Adds a CompiledModule abstraction on top of Cuda runtime compilation (NVRTC).

- Adds a cache of runtime compiled kernels. The cache returns a
  kernel immediately and leaves it compiling in the background; the
  kernel's methods wait for the compilation to finish (a standalone
  sketch of this pattern follows the list).

- Tests that runtime API and driver API streams are interchangeable
  when running a dynamically generated kernel.

- Adds proper use of contexts, one per device. The contexts are needed
  because the driver API is used to handle runtime compilation.

- Adds device properties to the Device struct.
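
The caching behavior described in the second bullet can be illustrated with a small, self-contained sketch in standard C++. This is only an illustration of the pattern, not the committed implementation: the real cache lives in KernelCache.cpp and the compilation path in Compile.cu below, and every name in the sketch (KernelCache, CachedKernel, FakeModule) is made up for the example.

// Illustrative sketch only: a cache that hands back a kernel wrapper
// immediately while compilation runs on a background thread. The real code
// is in KernelCache.cpp / Compile.cu; these names are not the Velox API.
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct FakeModule {
  std::string ptx; // stands in for the CUmodule produced by NVRTC
};

class CachedKernel {
 public:
  explicit CachedKernel(std::shared_future<std::shared_ptr<FakeModule>> module)
      : module_(std::move(module)) {}

  // Kernel methods block until the background compilation has finished.
  void launch() {
    auto module = module_.get(); // waits for the compile to complete
    // ... launching with the functions from 'module' would go here ...
  }

 private:
  std::shared_future<std::shared_ptr<FakeModule>> module_;
};

class KernelCache {
 public:
  // Returns right away; a cache miss starts an asynchronous compilation.
  std::shared_ptr<CachedKernel> get(const std::string& source) {
    std::lock_guard<std::mutex> l(mutex_);
    auto it = modules_.find(source);
    if (it == modules_.end()) {
      std::shared_future<std::shared_ptr<FakeModule>> future =
          std::async(std::launch::async, [source] {
            // Placeholder for the NVRTC compile + module load path.
            return std::make_shared<FakeModule>(FakeModule{source});
          });
      it = modules_.emplace(source, std::move(future)).first;
    }
    return std::make_shared<CachedKernel>(it->second);
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_future<std::shared_ptr<FakeModule>>>
      modules_;
};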
Orri Erling committed Oct 10, 2024
1 parent 0758d04 commit a4e7b81
Showing 10 changed files with 677 additions and 20 deletions.
2 changes: 2 additions & 0 deletions velox/experimental/wave/common/CMakeLists.txt
@@ -16,8 +16,10 @@ velox_add_library(
  velox_wave_common
  GpuArena.cpp
  Buffer.cpp
  Compile.cu
  Cuda.cu
  Exception.cpp
  KernelCache.cpp
  Type.cpp
  ResultStaging.cpp)

165 changes: 165 additions & 0 deletions velox/experimental/wave/common/Compile.cu
@@ -0,0 +1,165 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <fmt/format.h>
#include <gflags/gflags.h>
#include <nvrtc.h>
#include "velox/experimental/wave/common/Cuda.h"
#include "velox/experimental/wave/common/CudaUtil.cuh"
#include "velox/experimental/wave/common/Exception.h"

DEFINE_string(
    wavegen_architecture,
    "compute_80",
    "--gpu-architecture flag for generated code");

namespace facebook::velox::wave {

void nvrtcCheck(nvrtcResult result) {
  if (result != NVRTC_SUCCESS) {
    waveError(nvrtcGetErrorString(result));
  }
}

class CompiledModuleImpl : public CompiledModule {
 public:
  CompiledModuleImpl(CUmodule module, std::vector<CUfunction> kernels)
      : module_(module), kernels_(std::move(kernels)) {}

  ~CompiledModuleImpl() {
    auto result = cuModuleUnload(module_);
    if (result != CUDA_SUCCESS) {
      LOG(ERROR) << "Error in unloading module " << result;
    }
  }

  void launch(
      int32_t kernelIdx,
      int32_t numBlocks,
      int32_t numThreads,
      int32_t shared,
      Stream* stream,
      void** args) override;

  KernelInfo info(int32_t kernelIdx) override;

 private:
  CUmodule module_;
  std::vector<CUfunction> kernels_;
};

std::shared_ptr<CompiledModule> CompiledModule::create(const KernelSpec& spec) {
  nvrtcProgram prog;
  nvrtcCreateProgram(
      &prog,
      spec.code.c_str(), // buffer
      spec.filePath.c_str(), // name
      spec.numHeaders, // numHeaders
      spec.headers, // headers
      spec.headerNames); // includeNames
  for (auto& name : spec.entryPoints) {
    nvrtcCheck(nvrtcAddNameExpression(prog, name.c_str()));
  }
  auto architecture =
      fmt::format("--gpu-architecture={}", FLAGS_wavegen_architecture);
  const char* opts[] = {
      architecture.c_str(),
#ifndef NDEBUG
      "-G"
#else
      "-O3"
#endif
  };
  auto compileResult = nvrtcCompileProgram(
      prog, // prog
      sizeof(opts) / sizeof(char*), // numOptions
      opts); // options

  size_t logSize;

  nvrtcGetProgramLogSize(prog, &logSize);
  std::string log;
  log.resize(logSize);
  nvrtcGetProgramLog(prog, log.data());

  if (compileResult != NVRTC_SUCCESS) {
    nvrtcDestroyProgram(&prog);
    waveError(std::string("Cuda compilation error: ") + log);
  }
  // Obtain PTX from the program.
  size_t ptxSize;
  nvrtcCheck(nvrtcGetPTXSize(prog, &ptxSize));
  std::string ptx;
  ptx.resize(ptxSize);
  nvrtcCheck(nvrtcGetPTX(prog, ptx.data()));
  std::vector<std::string> loweredNames;
  for (auto& entry : spec.entryPoints) {
    const char* temp;
    nvrtcCheck(nvrtcGetLoweredName(prog, entry.c_str(), &temp));
    loweredNames.push_back(std::string(temp));
  }

  nvrtcDestroyProgram(&prog);

  CUmodule module;
  CU_CHECK(cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0));
  std::vector<CUfunction> funcs;
  for (auto& name : loweredNames) {
    funcs.emplace_back();
    CU_CHECK(cuModuleGetFunction(&funcs.back(), module, name.c_str()));
  }
  return std::make_shared<CompiledModuleImpl>(module, std::move(funcs));
}

void CompiledModuleImpl::launch(
    int32_t kernelIdx,
    int32_t numBlocks,
    int32_t numThreads,
    int32_t shared,
    Stream* stream,
    void** args) {
  auto result = cuLaunchKernel(
      kernels_[kernelIdx],
      numBlocks,
      1,
      1, // grid dim
      numThreads,
      1,
      1, // block dim
      shared,
      reinterpret_cast<CUstream>(stream->stream()->stream),
      args,
      0);
  CU_CHECK(result);
}

KernelInfo CompiledModuleImpl::info(int32_t kernelIdx) {
  KernelInfo info;
  auto f = kernels_[kernelIdx];
  cuFuncGetAttribute(&info.numRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, f);
  cuFuncGetAttribute(
      &info.sharedMemory, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, f);
  cuFuncGetAttribute(
      &info.maxThreadsPerBlock, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, f);
  int32_t max;
  cuOccupancyMaxActiveBlocksPerMultiprocessor(&max, f, 256, 0);
  info.maxOccupancy0 = max;
  cuOccupancyMaxActiveBlocksPerMultiprocessor(&max, f, 256, 256 * 32);
  info.maxOccupancy32 = max;
  return info;
}

} // namespace facebook::velox::wave
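
For orientation, below is a minimal usage sketch of the API added in this file. CompiledModule::create, launch, info, the KernelSpec field names, Stream, getDevice and setDevice all come from this commit; the kernel source, grid/block sizes, the assumption that KernelSpec is default-constructible with std::string/std::vector fields, and the glog logging are illustrative only, and error checking on the raw CUDA runtime calls is elided.

// Sketch only: compile and launch a trivial generated kernel through the
// CompiledModule API above. Kernel source, sizes, and logging are made up.
#include <cuda_runtime.h>
#include <glog/logging.h>
#include "velox/experimental/wave/common/Cuda.h"

using namespace facebook::velox::wave;

void runGeneratedKernel() {
  setDevice(getDevice(0)); // make device 0 and its context current

  KernelSpec spec;
  spec.code =
      "extern \"C\" __global__ void plusOne(int* data) {\n"
      "  data[threadIdx.x + blockIdx.x * blockDim.x] += 1;\n"
      "}\n";
  spec.filePath = "plusOne.cu";
  spec.numHeaders = 0;
  spec.headers = nullptr;
  spec.headerNames = nullptr;
  spec.entryPoints = {"plusOne"};

  auto module = CompiledModule::create(spec); // NVRTC + cuModuleLoadDataEx

  int* data = nullptr;
  cudaMalloc(&data, 4 * 256 * sizeof(int));   // error checking elided
  cudaMemset(data, 0, 4 * 256 * sizeof(int));
  void* args[] = {&data};

  Stream stream;
  module->launch(
      /*kernelIdx=*/0,
      /*numBlocks=*/4,
      /*numThreads=*/256,
      /*shared=*/0,
      &stream,
      args);
  stream.wait();

  auto info = module->info(0);
  LOG(INFO) << "numRegs=" << info.numRegs
            << " maxThreadsPerBlock=" << info.maxThreadsPerBlock
            << " maxOccupancy0=" << info.maxOccupancy0;

  cudaFree(data);
}
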
108 changes: 98 additions & 10 deletions velox/experimental/wave/common/Cuda.cu
@@ -21,10 +21,19 @@
#include "velox/experimental/wave/common/CudaUtil.cuh"
#include "velox/experimental/wave/common/Exception.h"

#include <mutex>
#include <sstream>

namespace facebook::velox::wave {

void cuCheck(CUresult result, const char* file, int32_t line) {
  if (result != CUDA_SUCCESS) {
    const char* str;
    cuGetErrorString(result, &str);
    waveError(fmt::format("Cuda error: {}:{} {}", file, line, str));
  }
}

void cudaCheck(cudaError_t err, const char* file, int line) {
  if (err == cudaSuccess) {
    return;
@@ -43,6 +52,91 @@ void cudaCheckFatal(cudaError_t err, const char* file, int line) {
  exit(1);
}

namespace {
std::mutex ctxMutex;
bool driverInited = false;

// A context for each device. Each is initialized on first use by retaining
// the device's primary context.
std::vector<CUcontext> contexts;
// Device structs, 1:1 with contexts.
std::vector<std::unique_ptr<Device>> devices;

Device* setDriverDevice(int32_t deviceId) {
  if (!driverInited) {
    std::lock_guard<std::mutex> l(ctxMutex);
    CU_CHECK(cuInit(0));
    int32_t cnt;
    CU_CHECK(cuDeviceGetCount(&cnt));
    contexts.resize(cnt);
    devices.resize(cnt);
    if (cnt == 0) {
      waveError("No Cuda devices found");
    }
    driverInited = true;
  }
  if (deviceId >= contexts.size()) {
    waveError(fmt::format("Bad device id {}", deviceId));
  }
  if (contexts[deviceId] != nullptr) {
    cuCtxSetCurrent(contexts[deviceId]);
    return devices[deviceId].get();
  }
  {
    std::lock_guard<std::mutex> l(ctxMutex);
    CUdevice dev;
    CU_CHECK(cuDeviceGet(&dev, deviceId));
    CU_CHECK(cuDevicePrimaryCtxRetain(&contexts[deviceId], dev));
    devices[deviceId] = std::make_unique<Device>(deviceId);
    cudaDeviceProp prop;
    CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceId));
    auto& device = devices[deviceId];
    device->model = prop.name;
    device->major = prop.major;
    device->minor = prop.minor;
    device->globalMB = prop.totalGlobalMem >> 20;
    device->numSM = prop.multiProcessorCount;
    device->sharedMemPerSM = prop.sharedMemPerMultiprocessor;
    device->L2Size = prop.l2CacheSize;
    device->persistingL2MaxSize = prop.persistingL2CacheMaxSize;
  }
  CU_CHECK(cuCtxSetCurrent(contexts[deviceId]));
  return devices[deviceId].get();
}

} // namespace

Device* currentDevice() {
  CUcontext ctx;
  CU_CHECK(cuCtxGetCurrent(&ctx));
  if (!ctx) {
    return nullptr;
  }
  for (auto i = 0; i < contexts.size(); ++i) {
    if (contexts[i] == ctx) {
      return devices[i].get();
    }
  }
  waveError("Device context not found. Inconsistent state.");
  return nullptr;
}

Device* getDevice(int32_t deviceId) {
  Device* save = nullptr;
  if (driverInited) {
    save = currentDevice();
  }
  auto* dev = setDriverDevice(deviceId);
  if (save) {
    setDevice(save);
  }
  return dev;
}

void setDevice(Device* device) {
  setDriverDevice(device->deviceId);
  CUDA_CHECK(cudaSetDevice(device->deviceId));
}

namespace {
class CudaManagedAllocator : public GpuAllocator {
public:
Expand Down Expand Up @@ -106,23 +200,17 @@ GpuAllocator* getHostAllocator(Device* /*device*/) {
   return allocator;
 }
 
-// Always returns device 0.
-Device* getDevice(int32_t /*preferredDevice*/) {
-  static Device device(0);
-  return &device;
-}
-
-void setDevice(Device* device) {
-  CUDA_CHECK(cudaSetDevice(device->deviceId));
-}
+Stream::Stream(std::unique_ptr<StreamImpl> impl) : stream_(std::move(impl)) {}
 
 Stream::Stream() {
   stream_ = std::make_unique<StreamImpl>();
   CUDA_CHECK(cudaStreamCreate(&stream_->stream));
 }
 
 Stream::~Stream() {
-  cudaStreamDestroy(stream_->stream);
+  if (stream_->stream) {
+    cudaStreamDestroy(stream_->stream);
+  }
 }
 
 void Stream::wait() {
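
The context handling added above can be exercised with a short sketch: getDevice() initializes the driver and the device's primary context on first use, setDevice() makes both the driver context and the runtime device current, and setDriverDevice() fills in the Device properties printed below. The functions and fields come from this commit; the logging format and the glog include are assumptions.

// Sketch only: select a device and read the properties populated above.
#include <glog/logging.h>
#include "velox/experimental/wave/common/Cuda.h"

using namespace facebook::velox::wave;

void describeDevice(int32_t id) {
  Device* device = getDevice(id); // retains the primary context on first use
  setDevice(device);              // driver context + cudaSetDevice
  LOG(INFO) << "Device " << device->deviceId << ": " << device->model
            << " sm_" << device->major << device->minor
            << ", SMs=" << device->numSM
            << ", global=" << device->globalMB << "MB"
            << ", sharedPerSM=" << device->sharedMemPerSM
            << ", L2=" << device->L2Size
            << ", persistingL2Max=" << device->persistingL2MaxSize;
  // The context made current above maps back to the same Device*.
  CHECK(currentDevice() == device);
}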