Skip to content

Commit

Permalink
[CUDA] Initial support for dynamic shared memory (apache#8466)
Browse files Browse the repository at this point in the history
* send dyn shmem size to runtime

* add dyn shared storage scope

* associate buffer var and its storage scoe in split_host_device

* tried NVPTX but failed with INVALID_PTX error

* test stub

* dynamic shmem reduce working

* log2 issue fixed

* nvptx working

* refactor llvm shmem allocation

* make linkage argument

* support rocm too

* send dyn shmem param to hip runtime

* remove alloc map from split_host_device.cc

* remove attr::storage_scope from split_host_device

* lint fix

* formatting

* update calling convention doc

* minor update to test

* remove log

* remove kDynShared, dyn.shared -> shared.dyn

* support backward compat

* update json/binary reader/writer

* thread_axis_tags -> launch_param_tags

* ThreadAxisConfig -> LaunchParamConfig

* remove use_dynamic_shared_memory from FunctionInfo meta data

* revert change in test_tir_ir_builder.py

* make sure kUseDynamicSharedMemoryTag is the last tag

* remove continue

* update doc string following name change

* more comment update following name change

Co-authored-by: masa <masa@pop-os.localdomain>
Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
  • Loading branch information
3 people authored and ylc committed Jan 13, 2022
1 parent a1e708c commit 5972870
Show file tree
Hide file tree
Showing 22 changed files with 277 additions and 141 deletions.
4 changes: 2 additions & 2 deletions docs/dev/codebase_walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunct
auto it = fmap_.find(name);
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}

Expand All @@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(
fcache_[device_id],
wl.grid_dim(0),
Expand Down
11 changes: 10 additions & 1 deletion include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,17 +240,26 @@ namespace attr {
*
* Call(f,
* [arg1, arg2, ..., arg_n,
* work_size_1, work_size_2, ... work_size_m])
* work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])
*
* Here n = len(arg), m = len(work_size) = len(device_thread_axis).
*
* When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted.
*
* The list of device_thread_axis indicates how can be bind the
* work_size arguments to the corresponding threads.
*
* \sa tvm::CallingConv::kDeviceKernelLaunch
*/
constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";

/*!
* \brief Whether or not use dynamic shared memory.
*
* Type: Integer
*/
constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";

/*!
* \brief Whether to set noalias rule on the function arguments.
*
Expand Down
14 changes: 7 additions & 7 deletions src/runtime/cuda/cuda_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ class CUDAWrappedFunc {
public:
// initialize the CUDA function.
void Init(CUDAModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_void_args, const std::vector<std::string>& thread_axis_tags) {
size_t num_void_args, const std::vector<std::string>& launch_param_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
launch_param_config_.Init(num_void_args, launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
Expand All @@ -168,10 +168,10 @@ class CUDAWrappedFunc {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), 0, strm, void_args, nullptr);
wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr);
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
const char* msg;
cuGetErrorName(result, &msg);
Expand Down Expand Up @@ -201,8 +201,8 @@ class CUDAWrappedFunc {
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<CUfunction, kMaxNumGPUs> fcache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
// launch parameters configuration
LaunchParamConfig launch_param_config_;
};

class CUDAPrepGlobalBarrier {
Expand Down Expand Up @@ -241,7 +241,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name,
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}

Expand Down
10 changes: 6 additions & 4 deletions src/runtime/file_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
writer->EndObject();
}

Expand All @@ -52,7 +52,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
helper.DeclareField("thread_axis_tags", &thread_axis_tags);
helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
helper.DeclareOptionalField("thread_axis_tags",
&launch_param_tags); // for backward compatibility
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
Expand All @@ -63,13 +65,13 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
writer->Write(thread_axis_tags);
writer->Write(launch_param_tags);
}

bool FunctionInfo::Load(dmlc::Stream* reader) {
if (!reader->Read(&name)) return false;
if (!reader->Read(&arg_types)) return false;
if (!reader->Read(&thread_axis_tags)) return false;
if (!reader->Read(&launch_param_tags)) return false;
return true;
}

Expand Down
5 changes: 4 additions & 1 deletion src/runtime/meta_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,14 @@ Module MetadataModuleCreate(
const std::unordered_map<std::string, NDArray>& metadata,
const std::unordered_map<std::string, std::vector<std::string>>& sym_vars);

/*! \brief A tag to specify whether or not dynamic shared memory is used */
constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory";

/*! \brief function information needed by device */
struct FunctionInfo {
std::string name;
std::vector<DLDataType> arg_types;
std::vector<std::string> thread_axis_tags;
std::vector<std::string> launch_param_tags;

void Save(dmlc::JSONWriter* writer) const;
void Load(dmlc::JSONReader* reader);
Expand Down
12 changes: 6 additions & 6 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ void SaveToBinary(dmlc::Stream* stream) final {
// initialize the METAL function.
void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
const std::vector<std::string>& launch_param_tags) {
w_ = metal::MetalWorkspace::Global();
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
std::fill(scache_.begin(), scache_.end(), (id<MTLComputePipelineState>)nil);
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags);
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int dev_id = t->device.device_id;
scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
Expand All @@ -201,7 +201,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
Expand Down Expand Up @@ -242,8 +242,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
// Device state cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<id<MTLComputePipelineState>, kMetalMaxNumDevice> scache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
// launch parameters configuration
LaunchParamConfig launch_param_config_;
};

PackedFunc MetalModuleNode::GetFunction(const std::string& name,
Expand All @@ -261,7 +261,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
MetalWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
info.thread_axis_tags);
info.launch_param_tags);
pf = PackFuncNonBufferArg(f, info.arg_types);
};
return pf;
Expand Down
14 changes: 7 additions & 7 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ class OpenCLWrappedFunc {
// initialize the OpenCL function.
void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
std::string func_name, std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) {
const std::vector<std::string>& launch_param_tags) {
w_ = m->GetGlobalWorkspace();
m_ = m;
sptr_ = sptr;
entry_ = entry;
func_name_ = func_name;
arg_size_ = arg_size;
thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags);
launch_param_config_.Init(arg_size.size(), launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
Expand All @@ -73,8 +73,8 @@ class OpenCLWrappedFunc {
OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg));
}
cl_command_queue queue = w_->GetQueue(t->device);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
ThreadWorkLoad wl = launch_param_config_.Extract(args);
cl_uint work_dim = static_cast<cl_uint>(launch_param_config_.work_dim());
for (cl_uint i = 0; i < work_dim; ++i) {
wl.work_size[i] *= wl.work_size[i + 3];
}
Expand All @@ -96,8 +96,8 @@ class OpenCLWrappedFunc {
std::string func_name_;
// convert code for void argument
std::vector<size_t> arg_size_;
// thread axis config
ThreadAxisConfig thread_axis_cfg_;
// launch parameters config
LaunchParamConfig launch_param_config_;
};

OpenCLModuleNode::~OpenCLModuleNode() {
Expand Down Expand Up @@ -148,7 +148,7 @@ PackedFunc OpenCLModuleNode::GetFunction(const std::string& name,
}
}
// initialize the wrapped func.
f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags);
f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}

Expand Down
19 changes: 10 additions & 9 deletions src/runtime/rocm/rocm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ class ROCMWrappedFunc {
public:
// initialize the ROCM function.
void Init(ROCMModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_void_args, const std::vector<std::string>& thread_axis_tags) {
size_t num_void_args, const std::vector<std::string>& launch_param_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
launch_param_config_.Init(num_void_args, launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const {
Expand All @@ -164,13 +164,14 @@ class ROCMWrappedFunc {

hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);

ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE,
&packed_nbytes, HIP_LAUNCH_PARAM_END};
// HIP supports only extra_args.
ROCM_DRIVER_CALL(hipModuleLaunchKernel(
fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0),
wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast<void**>(&config)));
ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr,
reinterpret_cast<void**>(&config)));
}

private:
Expand All @@ -183,8 +184,8 @@ class ROCMWrappedFunc {
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
// launch parameters configuration
LaunchParamConfig launch_param_config_;
};

PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
Expand All @@ -195,7 +196,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
ROCMWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncPackedArg(f, info.arg_types);
}

Expand Down
33 changes: 24 additions & 9 deletions src/runtime/thread_storage_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file thread_storage_scope.h
* \brief Extract thread axis configuration from TVMArgs.
* \brief Extract launch parameters configuration from TVMArgs.
*/
#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
Expand All @@ -29,6 +29,8 @@
#include <string>
#include <vector>

#include "meta_data.h"

namespace tvm {
namespace runtime {

Expand Down Expand Up @@ -182,6 +184,8 @@ struct ThreadScope {
struct ThreadWorkLoad {
// array, first three are thread configuration.
size_t work_size[6];
// Dynamic shared memory allocation size in bytes.
size_t dyn_shmem_size{0};
/*!
* \param i The block dimension.
* \return i-th block dim
Expand All @@ -193,17 +197,23 @@ struct ThreadWorkLoad {
*/
inline size_t grid_dim(size_t i) const { return work_size[i]; }
};
/*! \brief Thread axis configuration */
class ThreadAxisConfig {
/*! \brief Launch parameters configuration */
class LaunchParamConfig {
public:
void Init(size_t base, const std::vector<std::string>& thread_axis_tags) {
void Init(size_t base, const std::vector<std::string>& launch_param_tags) {
base_ = base;
std::vector<bool> filled(6, false);
for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
const std::string& tag = thread_axis_tags[i];
ThreadScope ts = ThreadScope::Create(tag);
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
filled[ts.rank * 3 + ts.dim_index] = true;
for (size_t i = 0; i < launch_param_tags.size(); ++i) {
const std::string& tag = launch_param_tags[i];
if (tag == kUseDynamicSharedMemoryTag) {
ICHECK_EQ(i, launch_param_tags.size() - 1)
<< "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags.";
use_dyn_shared_memory_ = true;
} else {
ThreadScope ts = ThreadScope::Create(tag);
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
filled[ts.rank * 3 + ts.dim_index] = true;
}
}
work_dim_ = 1;
for (int i = 0; i < 3; ++i) {
Expand All @@ -223,6 +233,9 @@ class ThreadAxisConfig {
w.work_size[arg_index_map_[i]] = size;
}
}
if (use_dyn_shared_memory_) {
w.dyn_shmem_size = static_cast<size_t>(x.values[base_ + arg_index_map_.size()].v_int64);
}
return w;
}
// return the work dim
Expand All @@ -235,6 +248,8 @@ class ThreadAxisConfig {
size_t work_dim_;
/*! \brief The index mapping. */
std::vector<uint32_t> arg_index_map_;
/*! \brief Whether or not use dynamic shared memory. */
bool use_dyn_shared_memory_{false};
};

} // namespace runtime
Expand Down
Loading

0 comments on commit 5972870

Please sign in to comment.