From f709892217d67671dd9b1b70158f9c89d3bae76e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Jun 2021 09:53:57 +0900 Subject: [PATCH 01/30] send dyn shmem size to runtime --- src/runtime/thread_storage_scope.h | 3 +++ src/tir/transforms/split_host_device.cc | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 9d140aedd810..4bdbbd110301 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -182,6 +182,8 @@ struct ThreadScope { struct ThreadWorkLoad { // array, first three are thread configuration. size_t work_size[6]; + // TODO + size_t dyn_shmem_size{0}; /*! * \param i The block dimension. * \return i-th block dim @@ -223,6 +225,7 @@ class ThreadAxisConfig { w.work_size[arg_index_map_[i]] = size; } } + w.dyn_shmem_size = static_cast(x.values[base_ + arg_index_map_.size()].v_int64); return w; } // return the work dim diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index f01d98707586..49639f9cabea 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -175,6 +175,7 @@ class VarUseDefAnalysis : public StmtExprMutator { Array undefined_; Array thread_axis_; Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; std::unordered_map use_count_; std::unordered_map def_count_; @@ -273,6 +274,8 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } + // dynamic shared memory size + call_args.push_back(m.dyn_shmem_size_); return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } From f20780bcf7a5ddd6fc24f173b20c08b606d11aa6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Jun 2021 10:23:53 +0900 Subject: [PATCH 02/30] add dyn shared storage scope --- src/runtime/cuda/cuda_module.cc | 2 +- src/runtime/thread_storage_scope.h | 17 ++++++++++++----- 
src/target/source/codegen_cuda.cc | 17 ++++++++++++----- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index a877bc634300..7b4a3accc6c3 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -171,7 +171,7 @@ class CUDAWrappedFunc { ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), - wl.block_dim(2), 0, strm, void_args, nullptr); + wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { const char* msg; cuGetErrorName(result, &msg); diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 4bdbbd110301..ea0eb4485304 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -42,21 +42,23 @@ enum class StorageRank { kGlobal = 0, /*! \brief shared memory among thread group */ kShared = 1, + /*! \brief dynamic shared memory among thread group */ + kDynShared = 2, /*! * \brief reserved for warp memory. * This is only used by programming model. * There is no such memory usually in GPU. * Instead, we can simulate it by registers and shuffle. */ - kWarp = 2, + kWarp = 3, /*! \brief thread local memory */ - kLocal = 3, + kLocal = 4, /*! \brief wmma scope memory of matrix_a */ - kWMMAMatrixA = 4, + kWMMAMatrixA = 5, /*! \brief wmma scope memory of matrix_b */ - kWMMAMatrixB = 5, + kWMMAMatrixB = 6, /*! \brief wmma scope memory of accumulator */ - kWMMAAccumulator = 6, + kWMMAAccumulator = 7, }; /*! 
@@ -96,6 +98,8 @@ struct StorageScope { return "global" + tag; case StorageRank::kShared: return "shared" + tag; + case StorageRank::kDynShared: + return "dyn.shared" + tag; case StorageRank::kWarp: return "warp" + tag; case StorageRank::kLocal: @@ -126,6 +130,9 @@ struct StorageScope { } else if (s.compare(0, 6, "shared") == 0) { r.rank = StorageRank::kShared; r.tag = s.substr(6, std::string::npos); + } else if (s.compare(0, 10, "dyn.shared") == 0) { + r.rank = StorageRank::kDynShared; + r.tag = s.substr(10, std::string::npos); } else if (s.compare(0, 4, "warp") == 0) { r.rank = StorageRank::kWarp; r.tag = s.substr(4, std::string::npos); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index d7dcbec7ebe3..7ab1c8ce0848 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -525,6 +525,8 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) "all global arrays as input instead"; if (scope == "shared") { os << "__shared__ "; + } else if (scope == "dyn.shared") { + os << "extern __shared__ "; } } @@ -726,12 +728,17 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); } - if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && - scope == "shared") { - constant_size = constant_size / (32 / op->dtype.bits()); + + if (scope == "dyn.shared") { + stream << ' ' << vid << "[];\n"; + } else { + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { + constant_size = constant_size / (32 / op->dtype.bits()); + } + stream << ' ' << vid << '[' << constant_size << "];\n"; } - stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); From 44ca4bfca7147034ba5377e0e1b46f3238194476 Mon Sep 17 00:00:00 2001 From: Masahiro 
Masuda Date: Mon, 21 Jun 2021 10:57:23 +0900 Subject: [PATCH 03/30] associate buffer var and its storage scope in split_host_device --- src/tir/transforms/split_host_device.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 49639f9cabea..55db6f993d4f 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -33,6 +33,8 @@ #include +#include "../../runtime/thread_storage_scope.h" + namespace tvm { namespace tir { @@ -60,6 +62,12 @@ class VarUseDefAnalysis : public StmtExprMutator { return GetRef(op); } return AttrStmt(op->node, op->attr_key, value, body); + } else if (op->attr_key == attr::storage_scope) { + const VarNode* v = op->node.as(); + ICHECK(v); + auto scope = op->value.as()->value; + alloc_storage_scope_[v] = runtime::StorageScope::Create(scope); + return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); } @@ -89,6 +97,10 @@ class VarUseDefAnalysis : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { this->HandleDef(op->buffer_var.get()); + if (alloc_storage_scope_[op->buffer_var.get()].rank == runtime::StorageRank::kDynShared) { + // CHECK_EQ(dyn_shmem_size_, PrimExpr(0)) << "Only one dynamic shmem allowed for now"; + dyn_shmem_size_ = op->constant_allocation_size(); + } return StmtExprMutator::VisitStmt_(op); } @@ -176,6 +188,7 @@ class VarUseDefAnalysis : public StmtExprMutator { Array thread_axis_; Array thread_extent_; PrimExpr dyn_shmem_size_{0}; + std::unordered_map alloc_storage_scope_; std::unordered_map use_count_; std::unordered_map def_count_; From d9afa3787bfb788103df59e7ce8e8253180acd1e Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 28 Jun 2021 04:45:28 +0900 Subject: [PATCH 04/30] tried NVPTX but failed with INVALID_PTX error --- src/target/llvm/codegen_nvptx.cc | 67 ++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 25
deletions(-) diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 43ea0e6b7ae9..ea24bb5e0965 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -48,41 +48,20 @@ class CodeGenNVPTX : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } // maximum necessary alignment in the NV devices if (info.alignment > 16) { info.alignment = 16; } + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. 
- llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { -#if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); -#else - alloca->setAlignment(info.alignment); -#endif - } - buf = alloca; - } else { - ICHECK(storage_scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; + if (storage_scope.rank == runtime::StorageRank::kDynShared) { // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), 0); // Allocate shared memory in global, address_space = 3 llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, + *module_, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(info.alignment)); @@ -90,6 +69,44 @@ class CodeGenNVPTX : public CodeGenLLVM { global->setAlignment(info.alignment); #endif buf = global; + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + if (storage_scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. 
+ llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { +#if TVM_LLVM_VERSION >= 100 + alloca->setAlignment(llvm::Align(info.alignment)); +#else + alloca->setAlignment(info.alignment); +#endif + } + buf = alloca; + } else { + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + // Shared memory: address space == 3 + const unsigned shared_address_space = 3; + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); + // Allocate shared memory in global, address_space = 3 + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); +#if TVM_LLVM_VERSION >= 100 + global->setAlignment(llvm::Align(info.alignment)); +#else + global->setAlignment(info.alignment); +#endif + buf = global; + } } buf = builder_->CreatePointerCast( From 2c81667dd46420541f54616c507085dc31de4be6 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 28 Jun 2021 05:23:05 +0900 Subject: [PATCH 05/30] test stub --- tests/python/unittest/test_tir_ir_builder.py | 62 +++++++++++++++++--- 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 355d3abed559..a32bf54f088b 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -497,13 +497,57 @@ def check_target(target, ir): check_target("vulkan", searchsorted_ir_gpu) +@tvm.testing.requires_gpu +def test_dyn_shared(): + n = te.size_var("n") + dtype = "float32" + A = te.placeholder((n,), name="A") + + def test_device_ir(A, C): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + tx = 
te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + Aptr = ib.buffer_ptr(A) + Cptr = ib.buffer_ptr(C) + Cptr[tx] = Aptr[tx] + 1 + body = ib.get() + return body + + C = te.extern( + A.shape, + [A], + lambda ins, outs: test_device_ir(ins[0], outs[0]), + name="vector_add", + dtype=dtype, + ) + s = te.create_schedule(C.op) + + def check_target(target): + n = 1024 + if not tvm.testing.device_enabled(target): + return + # build and invoke the kernel. + fadd = tvm.build(s, [A, C], target) + dev = tvm.device(target, 0) + # launch the kernel. + for n in [512, 1024]: + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + fadd(a, c) + tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) + + check_target("cuda") + + if __name__ == "__main__": - test_prefetch() - test_if() - test_for() - test_cpu() - test_gpu() - test_while_vectorize() - test_while_collatz() - test_while_mandel() - test_while_binary_search() + # test_prefetch() + # test_if() + # test_for() + # test_cpu() + # test_gpu() + # test_while_vectorize() + # test_while_collatz() + # test_while_mandel() + # test_while_binary_search() + test_dyn_shared() From 51e87cb1fb2127c90a3a205fd21a0e7222f8c7b3 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 28 Jun 2021 05:54:38 +0900 Subject: [PATCH 06/30] dynamic shmem reduce working --- src/target/source/codegen_cuda.cc | 11 +++--- src/tir/transforms/split_host_device.cc | 6 +++- tests/python/unittest/test_tir_ir_builder.py | 35 ++++++++++++-------- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 7ab1c8ce0848..dc78b26f44d5 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -705,9 +705,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); - int32_t constant_size = 
op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; std::string scope = GetPtrStorageScope(op->buffer_var); + const VarNode* buffer = op->buffer_var.as(); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || @@ -721,8 +720,6 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } - const VarNode* buffer = op->buffer_var.as(); - constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } else { PrintStorageScope(scope, stream); @@ -732,6 +729,12 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { if (scope == "dyn.shared") { stream << ' ' << vid << "[];\n"; } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + + if (scope.find("wmma.") == 0) { + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); + } if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1)) && scope == "shared") { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 55db6f993d4f..76e33812452c 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -99,7 +99,11 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->buffer_var.get()); if (alloc_storage_scope_[op->buffer_var.get()].rank == runtime::StorageRank::kDynShared) { // CHECK_EQ(dyn_shmem_size_, PrimExpr(0)) << "Only one dynamic shmem allowed for now"; - dyn_shmem_size_ = op->constant_allocation_size(); + dyn_shmem_size_ = op->extents[0]; + for (size_t i = 1; i < op->extents.size(); ++i) { + dyn_shmem_size_ *= op->extents[i]; 
+ } + dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); } return StmtExprMutator::VisitStmt_(op); } diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index a32bf54f088b..bde4bdbf6e75 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -510,32 +510,41 @@ def test_device_ir(A, C): ib.scope_attr(tx, "thread_extent", n) Aptr = ib.buffer_ptr(A) Cptr = ib.buffer_ptr(C) - Cptr[tx] = Aptr[tx] + 1 - body = ib.get() - return body + temp = ib.allocate(dtype, (n,), scope="dyn.shared") + temp[tx] = Aptr[tx] + # depth = tvm.tir.log2(n) + depth = 9 + with ib.for_range(0, depth) as i: + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + d = n >> (i + 1) + with ib.if_scope(tx < d): + temp[tx] += temp[tx + d] + + Cptr[0] = temp[0] + return ib.get() C = te.extern( - A.shape, + (1,), [A], lambda ins, outs: test_device_ir(ins[0], outs[0]), - name="vector_add", + name="reduce", dtype=dtype, ) s = te.create_schedule(C.op) def check_target(target): - n = 1024 if not tvm.testing.device_enabled(target): return - # build and invoke the kernel. - fadd = tvm.build(s, [A, C], target) + + freduce = tvm.build(s, [A, C], target) dev = tvm.device(target, 0) - # launch the kernel. 
- for n in [512, 1024]: + + for n in [512]: a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) - fadd(a, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) + c = tvm.nd.array(np.zeros(1, dtype=C.dtype), dev) + freduce(a, c) + tvm.testing.assert_allclose(c.numpy()[0], np.sum(a.numpy())) + print(c.numpy()[0], np.sum(a.numpy())) check_target("cuda") From 95e1b81d0e23ad8682f4f0c68f14c1ff253f6677 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 28 Jun 2021 10:08:33 +0900 Subject: [PATCH 07/30] log2 issue fixed --- tests/python/unittest/test_tir_ir_builder.py | 26 ++++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index bde4bdbf6e75..576c3d262507 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -18,6 +18,7 @@ from tvm import te import numpy as np import tvm.testing +from tvm.topi.math import cast def test_for(): @@ -503,48 +504,47 @@ def test_dyn_shared(): dtype = "float32" A = te.placeholder((n,), name="A") - def test_device_ir(A, C): + def test_device_ir(A, B): n = A.shape[0] ib = tvm.tir.ir_builder.create() tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", n) Aptr = ib.buffer_ptr(A) - Cptr = ib.buffer_ptr(C) + Bptr = ib.buffer_ptr(B) temp = ib.allocate(dtype, (n,), scope="dyn.shared") temp[tx] = Aptr[tx] - # depth = tvm.tir.log2(n) - depth = 9 + depth = tvm.tir.log2(cast(n, "float32")) + with ib.for_range(0, depth) as i: ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) d = n >> (i + 1) with ib.if_scope(tx < d): temp[tx] += temp[tx + d] - Cptr[0] = temp[0] + Bptr[0] = temp[0] return ib.get() - C = te.extern( + B = te.extern( (1,), [A], lambda ins, outs: test_device_ir(ins[0], outs[0]), name="reduce", dtype=dtype, ) - s = te.create_schedule(C.op) + s = 
te.create_schedule(B.op) def check_target(target): if not tvm.testing.device_enabled(target): return - freduce = tvm.build(s, [A, C], target) + freduce = tvm.build(s, [A, B], target) dev = tvm.device(target, 0) - for n in [512]: + for n in [512, 1024]: a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(1, dtype=C.dtype), dev) - freduce(a, c) - tvm.testing.assert_allclose(c.numpy()[0], np.sum(a.numpy())) - print(c.numpy()[0], np.sum(a.numpy())) + b = tvm.nd.array(np.zeros(1, dtype=B.dtype), dev) + freduce(a, b) + tvm.testing.assert_allclose(b.numpy()[0], np.sum(a.numpy()), 1e-4, 1e-4) check_target("cuda") From 719f40dea834f0896177d76393e2021a95e794e0 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 28 Jun 2021 11:02:57 +0900 Subject: [PATCH 08/30] nvptx working --- src/target/llvm/codegen_nvptx.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index ea24bb5e0965..d7e7643153c1 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -61,7 +61,7 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), 0); // Allocate shared memory in global, address_space = 3 llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, ".shared", nullptr, + *module_, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, "buf", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(info.alignment)); From 6d9c7c42566839beb4e715488c1d17784bfbee16 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 28 Jun 2021 11:57:29 +0900 Subject: [PATCH 09/30] refactor llvm shmem allocation --- src/target/llvm/codegen_llvm.cc | 16 ++++++++++++ src/target/llvm/codegen_llvm.h | 4 +++ src/target/llvm/codegen_nvptx.cc | 27 ++------------------ 
tests/python/unittest/test_tir_ir_builder.py | 13 +++++++--- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bdae93b82aff..2205e564bd9b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -524,6 +524,22 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp *p_alignment = align_bits / 8; } +llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, + size_t size, + unsigned int shared_address_space, + int alignment) { + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, "shmem", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); +#if TVM_LLVM_VERSION >= 100 + global->setAlignment(llvm::Align(alignment)); +#else + global->setAlignment(alignment); +#endif + return global; +} + std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { #if TVM_LLVM_VERSION >= 100 auto debug_info = std::make_unique(); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 810e59be7214..9e19a3eea9e8 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -292,6 +292,10 @@ class CodeGenLLVM : public ExprFunctor, const Var& loop_var, const Stmt& body); // add alias information. void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index); + + llvm::GlobalVariable* AllocateSharedMemory(DataType dtype, size_t size, + unsigned int shared_address_space, int alignment); + // The IRBuilder. 
using IRBuilder = llvm::IRBuilder; // The current function diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index d7e7643153c1..1280acd04410 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -57,18 +57,7 @@ class CodeGenNVPTX : public CodeGenLLVM { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kDynShared) { // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), 0); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, "buf", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); -#if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); -#else - global->setAlignment(info.alignment); -#endif - buf = global; + buf = AllocateSharedMemory(op->dtype, 0, 3, info.alignment); } else { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; @@ -93,19 +82,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } else { ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; - // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); -#if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); -#else - 
global->setAlignment(info.alignment); -#endif - buf = global; + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment); } } diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 576c3d262507..a3c60ec3dba8 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -507,11 +507,15 @@ def test_dyn_shared(): def test_device_ir(A, B): n = A.shape[0] ib = tvm.tir.ir_builder.create() + tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", n) + + temp = ib.allocate(dtype, (n,), scope="dyn.shared") + Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) - temp = ib.allocate(dtype, (n,), scope="dyn.shared") + temp[tx] = Aptr[tx] depth = tvm.tir.log2(cast(n, "float32")) @@ -534,8 +538,8 @@ def test_device_ir(A, B): s = te.create_schedule(B.op) def check_target(target): - if not tvm.testing.device_enabled(target): - return + # if not tvm.testing.device_enabled(target): + # return freduce = tvm.build(s, [A, B], target) dev = tvm.device(target, 0) @@ -546,7 +550,8 @@ def check_target(target): freduce(a, b) tvm.testing.assert_allclose(b.numpy()[0], np.sum(a.numpy()), 1e-4, 1e-4) - check_target("cuda") + for target in ["cuda", "nvptx"]: + check_target(target) if __name__ == "__main__": From e0fbac21bc4332a10343ea71ebad5aa1f6acb3d3 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 28 Jun 2021 12:12:50 +0900 Subject: [PATCH 10/30] make linkage argument --- src/target/llvm/codegen_llvm.cc | 5 +++-- src/target/llvm/codegen_llvm.h | 3 ++- src/target/llvm/codegen_nvptx.cc | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 2205e564bd9b..3178e7e3749a 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -527,10 +527,11 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp llvm::GlobalVariable* 
CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size, unsigned int shared_address_space, - int alignment) { + int alignment, + llvm::GlobalValue::LinkageTypes linkage) { llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size); llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, "shmem", nullptr, + *module_, type, false, linkage, nullptr, "shmem", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(alignment)); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 9e19a3eea9e8..52c5b98a0025 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -294,7 +294,8 @@ class CodeGenLLVM : public ExprFunctor, void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index); llvm::GlobalVariable* AllocateSharedMemory(DataType dtype, size_t size, - unsigned int shared_address_space, int alignment); + unsigned int shared_address_space, int alignment, + llvm::GlobalValue::LinkageTypes linkage); // The IRBuilder. 
using IRBuilder = llvm::IRBuilder; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 1280acd04410..bcd5c658684e 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -57,7 +57,8 @@ class CodeGenNVPTX : public CodeGenLLVM { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kDynShared) { // Shared memory: address space == 3 - buf = AllocateSharedMemory(op->dtype, 0, 3, info.alignment); + buf = + AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage); } else { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; @@ -82,7 +83,8 @@ class CodeGenNVPTX : public CodeGenLLVM { } else { ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; - buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment); + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, + llvm::GlobalValue::PrivateLinkage); } } From d78152642a8439e9072fadcfb06cd987dfa8f7a2 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 28 Jun 2021 12:14:18 +0900 Subject: [PATCH 11/30] support rocm too --- src/target/llvm/codegen_amdgpu.cc | 69 ++++++++++++++----------------- 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 9aec8f4e867b..7e87789cee06 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -72,51 +72,44 @@ class CodeGenAMDGPU : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - 
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the AMD devices - if (info.alignment > 16) { - info.alignment = 16; - } auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { -#if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); -#else - alloca->setAlignment(info.alignment); -#endif - } - buf = alloca; + + if (storage_scope.rank == runtime::StorageRank::kDynShared) { + buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16), + llvm::GlobalValue::ExternalLinkage); } else { - ICHECK(storage_scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; - // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); - if (global->getAlignment() < static_cast(info.alignment)) { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = 
GetTempAllocaAlignment(op->dtype, constant_size); + } + // maximum necessary alignment in the AMD devices + if (info.alignment > 16) { + info.alignment = 16; + } + if (storage_scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif + } + buf = alloca; + } else { + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + // Shared memory: address space == 3 + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, + llvm::GlobalValue::PrivateLinkage); } - buf = global; } buf = builder_->CreatePointerCast( From 933f9c54cbe2b25ff9c8a354b3afe9f56dfab193 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 28 Jun 2021 12:40:29 +0900 Subject: [PATCH 12/30] send dyn shmem param to hip runtime --- src/runtime/rocm/rocm_module.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 567557c56794..e809ac66d73a 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -167,10 +167,11 @@ class ROCMWrappedFunc { ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, HIP_LAUNCH_PARAM_END}; + LOG(INFO) << "dynamic shared mem size: " << wl.dyn_shmem_size; // HIP supports only extra_args. 
ROCM_DRIVER_CALL(hipModuleLaunchKernel( fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), - wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast(&config))); + wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr, reinterpret_cast(&config))); } private: From 509a8c1c02ffa204d56c91f5239b1d290dc853ec Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 29 Jun 2021 18:56:09 +0900 Subject: [PATCH 13/30] remove alloc map from split_host_device.cc --- src/tir/transforms/split_host_device.cc | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 76e33812452c..bebfa94971b8 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -62,12 +62,6 @@ class VarUseDefAnalysis : public StmtExprMutator { return GetRef(op); } return AttrStmt(op->node, op->attr_key, value, body); - } else if (op->attr_key == attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - auto scope = op->value.as()->value; - alloc_storage_scope_[v] = runtime::StorageScope::Create(scope); - return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); } @@ -97,7 +91,8 @@ class VarUseDefAnalysis : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { this->HandleDef(op->buffer_var.get()); - if (alloc_storage_scope_[op->buffer_var.get()].rank == runtime::StorageRank::kDynShared) { + auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kDynShared) { // CHECK_EQ(dyn_shmem_size_, PrimExpr(0)) << "Only one dynamic shmem allowed for now"; dyn_shmem_size_ = op->extents[0]; for (size_t i = 1; i < op->extents.size(); ++i) { @@ -192,7 +187,6 @@ class VarUseDefAnalysis : public StmtExprMutator { Array thread_axis_; Array thread_extent_; PrimExpr dyn_shmem_size_{0}; - 
std::unordered_map alloc_storage_scope_; std::unordered_map use_count_; std::unordered_map def_count_; From 0ea0962996d2baa7471df6c073036b3c9da47a11 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 14 Jul 2021 11:19:42 +0900 Subject: [PATCH 14/30] remove attr::storage_scope from split_host_device --- src/target/llvm/codegen_amdgpu.cc | 1 + src/tir/transforms/split_host_device.cc | 9 ++++++--- tests/python/unittest/test_tir_ir_builder.py | 18 +++++++++--------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 7e87789cee06..12bedfcdffea 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -76,6 +76,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kDynShared) { + LOG(WARNING) << "Dynamic shared memory support for rocm is experimental."; buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16), llvm::GlobalValue::ExternalLinkage); } else { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index bebfa94971b8..ee1cb2e21f33 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -34,6 +34,7 @@ #include #include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -91,14 +92,16 @@ class VarUseDefAnalysis : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { this->HandleDef(op->buffer_var.get()); - auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kDynShared) { - // CHECK_EQ(dyn_shmem_size_, PrimExpr(0)) << "Only one dynamic shmem allowed for now"; + ICHECK_EQ(use_dyn_shmem_, 
false) << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); dyn_shmem_size_ = op->extents[0]; for (size_t i = 1; i < op->extents.size(); ++i) { dyn_shmem_size_ *= op->extents[i]; } dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); + use_dyn_shmem_ = true; } return StmtExprMutator::VisitStmt_(op); } @@ -193,6 +196,7 @@ class VarUseDefAnalysis : public StmtExprMutator { private: ExprDeepEqual deep_equal_; std::unordered_map let_binding_; + bool use_dyn_shmem_{false}; }; Array UndefinedVars(const Stmt& stmt, const Array& args) { @@ -285,7 +289,6 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - // dynamic shared memory size call_args.push_back(m.dyn_shmem_size_); return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index a3c60ec3dba8..e2a277f9f307 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -555,13 +555,13 @@ def check_target(target): if __name__ == "__main__": - # test_prefetch() - # test_if() - # test_for() - # test_cpu() - # test_gpu() - # test_while_vectorize() - # test_while_collatz() - # test_while_mandel() - # test_while_binary_search() + test_prefetch() + test_if() + test_for() + test_cpu() + test_gpu() + test_while_vectorize() + test_while_collatz() + test_while_mandel() + test_while_binary_search() test_dyn_shared() From 84666a43bff0b9638762b26527bb437135a52919 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 14 Jul 2021 11:21:21 +0900 Subject: [PATCH 15/30] lint fix --- src/runtime/rocm/rocm_module.cc | 7 ++++--- src/runtime/thread_storage_scope.h | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index e809ac66d73a..b4356c1fff2e 100644 --- 
a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -169,9 +169,10 @@ class ROCMWrappedFunc { &packed_nbytes, HIP_LAUNCH_PARAM_END}; LOG(INFO) << "dynamic shared mem size: " << wl.dyn_shmem_size; // HIP supports only extra_args. - ROCM_DRIVER_CALL(hipModuleLaunchKernel( - fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), - wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr, reinterpret_cast(&config))); + ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), + wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), + wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr, + reinterpret_cast(&config))); } private: diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index ea0eb4485304..196d2f909ea0 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -189,7 +189,7 @@ struct ThreadScope { struct ThreadWorkLoad { // array, first three are thread configuration. size_t work_size[6]; - // TODO + // Dynamic shared memory allocation size in bytes. size_t dyn_shmem_size{0}; /*! * \param i The block dimension. 
From 5bcfacddc5fa192bcda34bf9e035cbfacd4b4a0d Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 14 Jul 2021 11:22:24 +0900 Subject: [PATCH 16/30] formatting --- src/target/llvm/codegen_llvm.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 3178e7e3749a..b83748b784b6 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -524,15 +524,14 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp *p_alignment = align_bits / 8; } -llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, - size_t size, +llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size, unsigned int shared_address_space, int alignment, llvm::GlobalValue::LinkageTypes linkage) { llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size); - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, linkage, nullptr, "shmem", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(alignment)); #else From 3ac84013942cfeb9e969cbccdf034bf796bad041 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 14 Jul 2021 11:24:38 +0900 Subject: [PATCH 17/30] update calling convention doc --- include/tvm/tir/function.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 25ed2f9ae8d1..2d65814ce34f 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -240,7 +240,7 @@ namespace attr { * * Call(f, * [arg1, arg2, ..., arg_n, - * work_size_1, work_size_2, ... work_size_m]) + * work_size_1, work_size_2, ... 
work_size_m, dyn_shmem_size]) * * Here n = len(arg), m = len(work_size) = len(device_thread_axis). * From 283e04cd22f08cbf23db3511bb7c253c56545265 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 14 Jul 2021 11:27:27 +0900 Subject: [PATCH 18/30] minor update to test --- tests/python/unittest/test_tir_ir_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index e2a277f9f307..df8b6e997916 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -511,7 +511,7 @@ def test_device_ir(A, B): tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", n) - temp = ib.allocate(dtype, (n,), scope="dyn.shared") + temp = ib.allocate(dtype, (n,), scope="dyn.shared") # n is symbolic size Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) @@ -538,8 +538,8 @@ def test_device_ir(A, B): s = te.create_schedule(B.op) def check_target(target): - # if not tvm.testing.device_enabled(target): - # return + if not tvm.testing.device_enabled(target): + return freduce = tvm.build(s, [A, B], target) dev = tvm.device(target, 0) From 4389ccbdf57b60ed600739b10707ba2101fcc682 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 15 Jul 2021 06:39:06 +0900 Subject: [PATCH 19/30] remove log --- src/runtime/rocm/rocm_module.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index b4356c1fff2e..d3ccde738e16 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -167,7 +167,6 @@ class ROCMWrappedFunc { ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, HIP_LAUNCH_PARAM_END}; - LOG(INFO) << "dynamic shared mem size: " << wl.dyn_shmem_size; // HIP supports only extra_args. 
ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), From 2085bfe380073126c2bd80b4203352d2190a459e Mon Sep 17 00:00:00 2001 From: masa Date: Sat, 17 Jul 2021 20:51:48 +0900 Subject: [PATCH 20/30] remove kDynShared, dyn.shared -> shared.dyn --- src/runtime/thread_storage_scope.h | 27 +++++++++---------- src/target/llvm/codegen_amdgpu.cc | 2 +- src/target/llvm/codegen_nvptx.cc | 2 +- src/target/source/codegen_cuda.cc | 4 +-- .../lower_device_storage_access_info.cc | 2 +- src/tir/transforms/split_host_device.cc | 12 ++++++--- src/tir/transforms/storage_rewrite.cc | 6 ++--- tests/python/unittest/test_tir_ir_builder.py | 20 +++++++------- 8 files changed, 40 insertions(+), 35 deletions(-) diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 196d2f909ea0..eb20e333c296 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -42,23 +42,21 @@ enum class StorageRank { kGlobal = 0, /*! \brief shared memory among thread group */ kShared = 1, - /*! \brief dynamic shared memory among thread group */ - kDynShared = 2, /*! * \brief reserved for warp memory. * This is only used by programming model. * There is no such memory usually in GPU. * Instead, we can simulate it by registers and shuffle. */ - kWarp = 3, + kWarp = 2, /*! \brief thread local memory */ - kLocal = 4, + kLocal = 3, /*! \brief wmma scope memory of matrix_a */ - kWMMAMatrixA = 5, + kWMMAMatrixA = 4, /*! \brief wmma scope memory of matrix_b */ - kWMMAMatrixB = 6, + kWMMAMatrixB = 5, /*! \brief wmma scope memory of accumulator */ - kWMMAAccumulator = 7, + kWMMAAccumulator = 6, }; /*! 
@@ -98,8 +96,6 @@ struct StorageScope { return "global" + tag; case StorageRank::kShared: return "shared" + tag; - case StorageRank::kDynShared: - return "dyn.shared" + tag; case StorageRank::kWarp: return "warp" + tag; case StorageRank::kLocal: @@ -130,9 +126,6 @@ struct StorageScope { } else if (s.compare(0, 6, "shared") == 0) { r.rank = StorageRank::kShared; r.tag = s.substr(6, std::string::npos); - } else if (s.compare(0, 10, "dyn.shared") == 0) { - r.rank = StorageRank::kDynShared; - r.tag = s.substr(10, std::string::npos); } else if (s.compare(0, 4, "warp") == 0) { r.rank = StorageRank::kWarp; r.tag = s.substr(4, std::string::npos); @@ -205,8 +198,10 @@ struct ThreadWorkLoad { /*! \brief Thread axis configuration */ class ThreadAxisConfig { public: - void Init(size_t base, const std::vector& thread_axis_tags) { + void Init(size_t base, const std::vector& thread_axis_tags, + bool use_dyn_shared_memory = false) { base_ = base; + use_dyn_shared_memory_ = use_dyn_shared_memory; std::vector filled(6, false); for (size_t i = 0; i < thread_axis_tags.size(); ++i) { const std::string& tag = thread_axis_tags[i]; @@ -232,7 +227,9 @@ class ThreadAxisConfig { w.work_size[arg_index_map_[i]] = size; } } - w.dyn_shmem_size = static_cast(x.values[base_ + arg_index_map_.size()].v_int64); + if (use_dyn_shared_memory_) { + w.dyn_shmem_size = static_cast(x.values[base_ + arg_index_map_.size()].v_int64); + } return w; } // return the work dim @@ -245,6 +242,8 @@ class ThreadAxisConfig { size_t work_dim_; /*! \brief The index mapping. */ std::vector arg_index_map_; + /*! \brief Whether or not use dynamic shared memory. 
*/ + bool use_dyn_shared_memory_{false}; }; } // namespace runtime diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 12bedfcdffea..7770e42086de 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -75,7 +75,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kDynShared) { + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { LOG(WARNING) << "Dynamic shared memory support for rocm is experimental."; buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16), llvm::GlobalValue::ExternalLinkage); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index bcd5c658684e..15543eda423f 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -55,7 +55,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kDynShared) { + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { // Shared memory: address space == 3 buf = AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index dc78b26f44d5..7897490730a3 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -525,7 +525,7 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) "all global arrays as input instead"; if (scope == "shared") { os << "__shared__ "; - } else if (scope == "dyn.shared") { + } else if (scope == "shared.dyn") { os << "extern __shared__ "; } } @@ -726,7 +726,7 @@ void 
CodeGenCUDA::VisitStmt_(const AllocateNode* op) { PrintType(op->dtype, stream); } - if (scope == "dyn.shared") { + if (scope == "shared.dyn") { stream << ' ' << vid << "[];\n"; } else { int32_t constant_size = op->constant_allocation_size(); diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 829b7d822d11..eafed837cee3 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -67,7 +67,7 @@ class StorageAccessInfoLower : public StmtExprMutator { StorageScope scope = StorageScope::Create(op->value.as()->value); StorageEntry e; e.scope = scope; - if (scope.tag.length() != 0) { + if (scope.tag.length() != 0 && scope.tag != ".dyn") { e.info = GetMemoryInfo(op->value.as()->value); ICHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index ee1cb2e21f33..795ae9d6a73a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -93,7 +93,7 @@ class VarUseDefAnalysis : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { this->HandleDef(op->buffer_var.get()); auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kDynShared) { + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; ICHECK_GT(op->extents.size(), 0); dyn_shmem_size_ = op->extents[0]; @@ -190,13 +190,13 @@ class VarUseDefAnalysis : public StmtExprMutator { Array thread_axis_; Array thread_extent_; PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; std::unordered_map use_count_; std::unordered_map def_count_; private: ExprDeepEqual deep_equal_; std::unordered_map 
let_binding_; - bool use_dyn_shmem_{false}; }; Array UndefinedVars(const Stmt& stmt, const Array& args) { @@ -278,6 +278,10 @@ class HostDeviceSplitter : public StmtMutator { WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); + if (m.use_dyn_shmem_) { + device_func = + WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); + } (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func); // generate calls to the device function @@ -289,7 +293,9 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - call_args.push_back(m.dyn_shmem_size_); + if (m.use_dyn_shmem_) { + call_args.push_back(m.dyn_shmem_size_); + } return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 613d02614b39..b216b8b848db 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -512,7 +512,7 @@ class StoragePlanRewriter : public StmtExprMutator { // try to find merge, for tagged memory for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { @@ -546,7 +546,7 @@ class StoragePlanRewriter : public StmtExprMutator { make_const(DataType::Int(32), 1), e->allocs[0]->extents); e->new_alloc = Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { MemoryInfo info = 
GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) @@ -587,7 +587,7 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = analyzer_.Simplify(combo_size); e->new_alloc = Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index df8b6e997916..582b6cecb44a 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -511,7 +511,7 @@ def test_device_ir(A, B): tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", n) - temp = ib.allocate(dtype, (n,), scope="dyn.shared") # n is symbolic size + temp = ib.allocate(dtype, (n,), scope="shared.dyn") # n is symbolic size Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) @@ -555,13 +555,13 @@ def check_target(target): if __name__ == "__main__": - test_prefetch() - test_if() - test_for() - test_cpu() - test_gpu() - test_while_vectorize() - test_while_collatz() - test_while_mandel() - test_while_binary_search() + # test_prefetch() + # test_if() + # test_for() + # test_cpu() + # test_gpu() + # test_while_vectorize() + # test_while_collatz() + # test_while_mandel() + # test_while_binary_search() test_dyn_shared() From 94b1a782c835ac0469f095100f95ff482524e556 Mon Sep 17 00:00:00 2001 From: masa Date: Sat, 17 Jul 2021 20:53:26 +0900 Subject: [PATCH 21/30] support backward compat --- include/tvm/tir/function.h | 9 +++++++++ src/runtime/cuda/cuda_module.cc | 9 ++++++--- src/runtime/meta_data.h | 1 + src/runtime/rocm/rocm_module.cc | 8 +++++--- 
src/target/build_common.h | 3 +++ 5 files changed, 24 insertions(+), 6 deletions(-) diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 2d65814ce34f..55f4fc62649c 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -244,6 +244,8 @@ namespace attr { * * Here n = len(arg), m = len(work_size) = len(device_thread_axis). * + * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted. + * * The list of device_thread_axis indicates how can be bind the * work_size arguments to the corresponding threads. * @@ -251,6 +253,13 @@ namespace attr { */ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; +/*! + * \brief Whether or not use dynamic shared memory. + * + * Type: Integer + */ +constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory"; + /*! * \brief Whether to set noalias rule on the function arguments. * diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 7b4a3accc6c3..4da060f0ddd8 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -153,12 +153,13 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. 
void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags) { + size_t num_void_args, const std::vector& thread_axis_tags, + bool use_dyn_shared_memory) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + thread_axis_cfg_.Init(num_void_args, thread_axis_tags, use_dyn_shared_memory); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -169,6 +170,7 @@ class CUDAWrappedFunc { } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + LOG(INFO) << "wl.dyn_shmem_size: " << wl.dyn_shmem_size; CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); @@ -241,7 +243,8 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags, + info.use_dyn_shared_memory); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index e3ec155dc291..a7e860f16e4b 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -104,6 +104,7 @@ struct FunctionInfo { std::string name; std::vector arg_types; std::vector thread_axis_tags; + bool use_dyn_shared_memory{false}; void Save(dmlc::JSONWriter* writer) const; void Load(dmlc::JSONReader* reader); diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index d3ccde738e16..40463a0912d9 100644 --- a/src/runtime/rocm/rocm_module.cc +++ 
b/src/runtime/rocm/rocm_module.cc @@ -147,12 +147,13 @@ class ROCMWrappedFunc { public: // initialize the ROCM function. void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags) { + size_t num_void_args, const std::vector& thread_axis_tags, + bool use_dyn_shared_memory) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + thread_axis_cfg_.Init(num_void_args, thread_axis_tags, use_dyn_shared_memory); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { @@ -196,7 +197,8 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; ROCMWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags, + info.use_dyn_shared_memory); return PackFuncPackedArg(f, info.arg_types); } diff --git a/src/target/build_common.h b/src/target/build_common.h index d2fe6468eef8..6c6ab995f7e2 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -56,6 +56,9 @@ inline std::unordered_map ExtractFuncInfo(co info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); } } + if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { + info.use_dyn_shared_memory = opt.value(); + } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol.value())] = info; } From 64a9d5e9e4c6a6829d4399b7ff605d10ebb0243d Mon Sep 17 00:00:00 2001 From: masa Date: Sat, 17 Jul 2021 21:07:09 +0900 Subject: [PATCH 22/30] update json/binary reader/writer --- src/runtime/cuda/cuda_module.cc | 1 - src/runtime/file_utils.cc | 4 ++++ tests/python/unittest/test_tir_ir_builder.py | 18 
+++++++++--------- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 4da060f0ddd8..01f86df7ac4d 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -170,7 +170,6 @@ class CUDAWrappedFunc { } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - LOG(INFO) << "wl.dyn_shmem_size: " << wl.dyn_shmem_size; CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 32dd1d8020c9..88f4547515c3 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -44,6 +44,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("arg_types", sarg_types); writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags); + writer->WriteObjectKeyValue("use_dyn_shared_memory", use_dyn_shared_memory); writer->EndObject(); } @@ -53,6 +54,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { helper.DeclareField("name", &name); helper.DeclareField("arg_types", &sarg_types); helper.DeclareField("thread_axis_tags", &thread_axis_tags); + helper.DeclareOptionalField("use_dyn_shared_memory", &use_dyn_shared_memory); helper.ReadAllFields(reader); arg_types.resize(sarg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { @@ -64,12 +66,14 @@ void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(name); writer->Write(arg_types); writer->Write(thread_axis_tags); + writer->Write(use_dyn_shared_memory); } bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&name)) return false; if (!reader->Read(&arg_types)) return false; if (!reader->Read(&thread_axis_tags)) return false; + if 
(!reader->Read(&use_dyn_shared_memory)) return false; return true; } diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 582b6cecb44a..0329134bb3fa 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -555,13 +555,13 @@ def check_target(target): if __name__ == "__main__": - # test_prefetch() - # test_if() - # test_for() - # test_cpu() - # test_gpu() - # test_while_vectorize() - # test_while_collatz() - # test_while_mandel() - # test_while_binary_search() + test_prefetch() + test_if() + test_for() + test_cpu() + test_gpu() + test_while_vectorize() + test_while_collatz() + test_while_mandel() + test_while_binary_search() test_dyn_shared() From c333cdcb0fb75fe93e511eaa98390bdf4f594165 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 20 Jul 2021 20:38:57 +0900 Subject: [PATCH 23/30] thread_axis_tags -> launch_param_tags --- docs/dev/codebase_walkthrough.rst | 2 +- src/runtime/cuda/cuda_module.cc | 6 +++--- src/runtime/file_utils.cc | 8 ++++---- src/runtime/meta_data.h | 2 +- src/runtime/metal/metal_module.mm | 6 +++--- src/runtime/opencl/opencl_module.cc | 6 +++--- src/runtime/rocm/rocm_module.cc | 6 +++--- src/runtime/thread_storage_scope.h | 6 +++--- src/runtime/vulkan/vulkan_wrapped_func.cc | 6 +++--- src/runtime/vulkan/vulkan_wrapped_func.h | 2 +- src/target/build_common.h | 2 +- web/src/webgpu.ts | 6 +++--- 12 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 60ab5e5ae9d2..ca5ebd8a47f2 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -183,7 +183,7 @@ The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunct auto it = fmap_.find(name); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, 
name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 01f86df7ac4d..c61f6a614d66 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -153,13 +153,13 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags, + size_t num_void_args, const std::vector& launch_param_tags, bool use_dyn_shared_memory) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags, use_dyn_shared_memory); + thread_axis_cfg_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -242,7 +242,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags, + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags, info.use_dyn_shared_memory); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 88f4547515c3..ea3f67e94fc2 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -43,7 +43,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("arg_types", sarg_types); - writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags); + writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags); writer->WriteObjectKeyValue("use_dyn_shared_memory", use_dyn_shared_memory); 
writer->EndObject(); } @@ -53,7 +53,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { std::vector sarg_types; helper.DeclareField("name", &name); helper.DeclareField("arg_types", &sarg_types); - helper.DeclareField("thread_axis_tags", &thread_axis_tags); + helper.DeclareField("launch_param_tags", &launch_param_tags); helper.DeclareOptionalField("use_dyn_shared_memory", &use_dyn_shared_memory); helper.ReadAllFields(reader); arg_types.resize(sarg_types.size()); @@ -65,14 +65,14 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(name); writer->Write(arg_types); - writer->Write(thread_axis_tags); + writer->Write(launch_param_tags); writer->Write(use_dyn_shared_memory); } bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&name)) return false; if (!reader->Read(&arg_types)) return false; - if (!reader->Read(&thread_axis_tags)) return false; + if (!reader->Read(&launch_param_tags)) return false; if (!reader->Read(&use_dyn_shared_memory)) return false; return true; } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index a7e860f16e4b..83c042b667d8 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -103,7 +103,7 @@ Module MetadataModuleCreate( struct FunctionInfo { std::string name; std::vector arg_types; - std::vector thread_axis_tags; + std::vector launch_param_tags; bool use_dyn_shared_memory{false}; void Save(dmlc::JSONWriter* writer) const; diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 88501880557e..bc1eccfe141f 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -178,7 +178,7 @@ void SaveToBinary(dmlc::Stream* stream) final { // initialize the METAL function. 
void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { w_ = metal::MetalWorkspace::Global(); m_ = m; sptr_ = sptr; @@ -186,7 +186,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; std::fill(scache_.begin(), scache_.end(), (id)nil); - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + thread_axis_cfg_.Init(num_buffer_args + num_pack_args, launch_param_tags); metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int dev_id = t->device.device_id; scache_[dev_id] = m->GetPipelineState(dev_id, func_name); @@ -261,7 +261,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); + info.launch_param_tags); pf = PackFuncNonBufferArg(f, info.arg_types); }; return pf; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 4040d82b33e7..9ff897b09be0 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -40,14 +40,14 @@ class OpenCLWrappedFunc { // initialize the OpenCL function. 
void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, std::string func_name, std::vector arg_size, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { w_ = m->GetGlobalWorkspace(); m_ = m; sptr_ = sptr; entry_ = entry; func_name_ = func_name; arg_size_ = arg_size; - thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags); + thread_axis_cfg_.Init(arg_size.size(), launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -148,7 +148,7 @@ PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, } } // initialize the wrapped func. - f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags); + f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 40463a0912d9..86102a400d18 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -147,13 +147,13 @@ class ROCMWrappedFunc { public: // initialize the ROCM function. 
void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags, + size_t num_void_args, const std::vector& launch_param_tags, bool use_dyn_shared_memory) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags, use_dyn_shared_memory); + thread_axis_cfg_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { @@ -197,7 +197,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; ROCMWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags, + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags, info.use_dyn_shared_memory); return PackFuncPackedArg(f, info.arg_types); } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index eb20e333c296..c717dc453ab4 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -198,13 +198,13 @@ struct ThreadWorkLoad { /*! 
\brief Thread axis configuration */ class ThreadAxisConfig { public: - void Init(size_t base, const std::vector& thread_axis_tags, + void Init(size_t base, const std::vector& launch_param_tags, bool use_dyn_shared_memory = false) { base_ = base; use_dyn_shared_memory_ = use_dyn_shared_memory; std::vector filled(6, false); - for (size_t i = 0; i < thread_axis_tags.size(); ++i) { - const std::string& tag = thread_axis_tags[i]; + for (size_t i = 0; i < launch_param_tags.size(); ++i) { + const std::string& tag = launch_param_tags[i]; ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); filled[ts.rank * 3 + ts.dim_index] = true; diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 103b2aa7692c..e64038e49f22 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -33,13 +33,13 @@ namespace vulkan { void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + thread_axis_cfg_.Init(num_buffer_args + num_pack_args, launch_param_tags); } void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, @@ -197,7 +197,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name, VulkanWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); + info.launch_param_tags); return PackFuncNonBufferArg(std::move(f), info.arg_types); } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index 
a174f22eba59..ffaaa085c610 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -58,7 +58,7 @@ class VulkanWrappedFunc { public: void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags); + const std::vector& launch_param_tags); void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; diff --git a/src/target/build_common.h b/src/target/build_common.h index 6c6ab995f7e2..e86d13c6e768 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -53,7 +53,7 @@ inline std::unordered_map ExtractFuncInfo(co if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { - info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); + info.launch_param_tags.push_back(thread_axis[i]->thread_tag); } } if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index f12837f421f8..226797eb7d19 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -39,7 +39,7 @@ export async function detectGPUDevice(): Promise { interface FunctionInfo { name: string; arg_types: Array; - thread_axis_tags: Array; + launch_param_tags: Array; } /** @@ -114,8 +114,8 @@ export class WebGPUContext { const dispatchToDim: Array = []; - for (let i = 0; i < finfo.thread_axis_tags.length; ++i) { - const tag: string = finfo.thread_axis_tags[i]; + for (let i = 0; i < finfo.launch_param_tags.length; ++i) { + const tag: string = finfo.launch_param_tags[i]; if (tag.startsWith("blockIdx.")) { const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); assert(target >= 0 && target < 3); From 19ec3099862693c17b5fc982226bcfb83c8b209b Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 20 Jul 2021 20:42:38 +0900 Subject: [PATCH 24/30] ThreadAxisConfig -> LaunchParamConfig 
--- docs/dev/codebase_walkthrough.rst | 2 +- src/runtime/cuda/cuda_module.cc | 6 +++--- src/runtime/metal/metal_module.mm | 6 +++--- src/runtime/opencl/opencl_module.cc | 8 ++++---- src/runtime/rocm/rocm_module.cc | 6 +++--- src/runtime/thread_storage_scope.h | 2 +- src/runtime/vulkan/vulkan_wrapped_func.cc | 4 ++-- src/runtime/vulkan/vulkan_wrapped_func.h | 2 +- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index ca5ebd8a47f2..efc8b32832c0 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal fcache_[device_id] = m_->GetFunc(device_id, func_name_); } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); CUresult result = cuLaunchKernel( fcache_[device_id], wl.grid_dim(0), diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index c61f6a614d66..046ccc3449bd 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -159,7 +159,7 @@ class CUDAWrappedFunc { sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory); + launch_param_config_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -169,7 +169,7 @@ class CUDAWrappedFunc { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), 
wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); @@ -203,7 +203,7 @@ class CUDAWrappedFunc { // mark as mutable, to enable lazy initialization mutable std::array fcache_; // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + LaunchParamConfig launch_param_config_; }; class CUDAPrepGlobalBarrier { diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index bc1eccfe141f..0e792ad21ec7 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -186,7 +186,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; std::fill(scache_.begin(), scache_.end(), (id)nil); - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, launch_param_tags); + launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int dev_id = t->device.device_id; scache_[dev_id] = m->GetPipelineState(dev_id, func_name); @@ -201,7 +201,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); @@ -243,7 +243,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons // mark as mutable, to enable lazy initialization mutable std::array, kMetalMaxNumDevice> scache_; // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + LaunchParamConfig launch_param_config_; }; PackedFunc 
MetalModuleNode::GetFunction(const std::string& name, diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 9ff897b09be0..a7efca5cae57 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -47,7 +47,7 @@ class OpenCLWrappedFunc { entry_ = entry; func_name_ = func_name; arg_size_ = arg_size; - thread_axis_cfg_.Init(arg_size.size(), launch_param_tags); + launch_param_config_.Init(arg_size.size(), launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -73,8 +73,8 @@ class OpenCLWrappedFunc { OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg)); } cl_command_queue queue = w_->GetQueue(t->device); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - cl_uint work_dim = static_cast(thread_axis_cfg_.work_dim()); + ThreadWorkLoad wl = launch_param_config_.Extract(args); + cl_uint work_dim = static_cast(launch_param_config_.work_dim()); for (cl_uint i = 0; i < work_dim; ++i) { wl.work_size[i] *= wl.work_size[i + 3]; } @@ -97,7 +97,7 @@ class OpenCLWrappedFunc { // convert code for void argument std::vector arg_size_; // thread axis config - ThreadAxisConfig thread_axis_cfg_; + LaunchParamConfig launch_param_config_; }; OpenCLModuleNode::~OpenCLModuleNode() { diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 86102a400d18..affae370d9e3 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -153,7 +153,7 @@ class ROCMWrappedFunc { sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory); + launch_param_config_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { @@ -165,7 +165,7 @@ 
class ROCMWrappedFunc { hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, HIP_LAUNCH_PARAM_END}; // HIP supports only extra_args. @@ -186,7 +186,7 @@ class ROCMWrappedFunc { // mark as mutable, to enable lazy initialization mutable std::array fcache_; // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + LaunchParamConfig launch_param_config_; }; PackedFunc ROCMModuleNode::GetFunction(const std::string& name, diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index c717dc453ab4..aba25a92729e 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -196,7 +196,7 @@ struct ThreadWorkLoad { inline size_t grid_dim(size_t i) const { return work_size[i]; } }; /*! \brief Thread axis configuration */ -class ThreadAxisConfig { +class LaunchParamConfig { public: void Init(size_t base, const std::vector& launch_param_tags, bool use_dyn_shared_memory = false) { diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index e64038e49f22..0712f723bb64 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -39,7 +39,7 @@ void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, func_name_ = func_name; num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, launch_param_tags); + launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); } void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, @@ -50,7 +50,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); } const 
auto& pipeline = scache_[device_id]; - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); std::vector descriptor_buffers; descriptor_buffers.resize(num_buffer_args_); for (size_t i = 0; i < num_buffer_args_; ++i) { diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index ffaaa085c610..d237d2ca8736 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -76,7 +76,7 @@ class VulkanWrappedFunc { // Device state cache per device. // mark as mutable, to enable lazy initialization // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + LaunchParamConfig launch_param_config_; mutable std::array, kVulkanMaxNumDevice> scache_; }; From 8deab0b9cafbab443bff933cc2c02f786a0223df Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 20 Jul 2021 21:11:05 +0900 Subject: [PATCH 25/30] remove use_dynamic_shared_memory from FunctionInfo meta data --- src/runtime/cuda/cuda_module.cc | 8 +++---- src/runtime/file_utils.cc | 8 +++---- src/runtime/meta_data.h | 4 +++- src/runtime/rocm/rocm_module.cc | 8 +++---- src/runtime/thread_storage_scope.h | 10 ++++++--- src/target/build_common.h | 4 +++- tests/python/unittest/test_tir_ir_builder.py | 22 ++++++++++---------- 7 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 046ccc3449bd..e3820d8ca569 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -153,13 +153,12 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. 
void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& launch_param_tags, - bool use_dyn_shared_memory) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - launch_param_config_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory); + launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -242,8 +241,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags, - info.use_dyn_shared_memory); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index ea3f67e94fc2..35832e83f59c 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -44,7 +44,6 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("arg_types", sarg_types); writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags); - writer->WriteObjectKeyValue("use_dyn_shared_memory", use_dyn_shared_memory); writer->EndObject(); } @@ -53,8 +52,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { std::vector sarg_types; helper.DeclareField("name", &name); helper.DeclareField("arg_types", &sarg_types); - helper.DeclareField("launch_param_tags", &launch_param_tags); - helper.DeclareOptionalField("use_dyn_shared_memory", &use_dyn_shared_memory); + helper.DeclareOptionalField("launch_param_tags", &launch_param_tags); + helper.DeclareOptionalField("thread_axis_tags", + 
&launch_param_tags); // for backward compatibility helper.ReadAllFields(reader); arg_types.resize(sarg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { @@ -66,14 +66,12 @@ void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(name); writer->Write(arg_types); writer->Write(launch_param_tags); - writer->Write(use_dyn_shared_memory); } bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&name)) return false; if (!reader->Read(&arg_types)) return false; if (!reader->Read(&launch_param_tags)) return false; - if (!reader->Read(&use_dyn_shared_memory)) return false; return true; } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 83c042b667d8..002012a1e1cc 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -99,12 +99,14 @@ Module MetadataModuleCreate( const std::unordered_map& metadata, const std::unordered_map>& sym_vars); +/*! \brief A tag to specify whether or not dynamic shared memory is used */ +constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; + /*! \brief function information needed by device */ struct FunctionInfo { std::string name; std::vector arg_types; std::vector launch_param_tags; - bool use_dyn_shared_memory{false}; void Save(dmlc::JSONWriter* writer) const; void Load(dmlc::JSONReader* reader); diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index affae370d9e3..e02a2c683fdb 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -147,13 +147,12 @@ class ROCMWrappedFunc { public: // initialize the ROCM function. 
void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& launch_param_tags, - bool use_dyn_shared_memory) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - launch_param_config_.Init(num_void_args, launch_param_tags, use_dyn_shared_memory); + launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { @@ -197,8 +196,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; ROCMWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags, - info.use_dyn_shared_memory); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncPackedArg(f, info.arg_types); } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index aba25a92729e..e8bbd87e1006 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -29,6 +29,8 @@ #include #include +#include "meta_data.h" + namespace tvm { namespace runtime { @@ -198,13 +200,15 @@ struct ThreadWorkLoad { /*! 
\brief Thread axis configuration */ class LaunchParamConfig { public: - void Init(size_t base, const std::vector& launch_param_tags, - bool use_dyn_shared_memory = false) { + void Init(size_t base, const std::vector& launch_param_tags) { base_ = base; - use_dyn_shared_memory_ = use_dyn_shared_memory; std::vector filled(6, false); for (size_t i = 0; i < launch_param_tags.size(); ++i) { const std::string& tag = launch_param_tags[i]; + if (tag == kUseDynamicSharedMemoryTag) { + use_dyn_shared_memory_ = true; + continue; + } ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); filled[ts.rank * 3 + ts.dim_index] = true; diff --git a/src/target/build_common.h b/src/target/build_common.h index e86d13c6e768..c66c2b52822e 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -57,7 +57,9 @@ inline std::unordered_map ExtractFuncInfo(co } } if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { - info.use_dyn_shared_memory = opt.value(); + if (opt.value()) { + info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); + } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol.value())] = info; diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 0329134bb3fa..6874678a3892 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -538,8 +538,8 @@ def test_device_ir(A, B): s = te.create_schedule(B.op) def check_target(target): - if not tvm.testing.device_enabled(target): - return + # if not tvm.testing.device_enabled(target): + # return freduce = tvm.build(s, [A, B], target) dev = tvm.device(target, 0) @@ -555,13 +555,13 @@ def check_target(target): if __name__ == "__main__": - test_prefetch() - test_if() - test_for() - test_cpu() - test_gpu() - test_while_vectorize() - test_while_collatz() - test_while_mandel() - test_while_binary_search() + # 
test_prefetch() + # test_if() + # test_for() + # test_cpu() + # test_gpu() + # test_while_vectorize() + # test_while_collatz() + # test_while_mandel() + # test_while_binary_search() test_dyn_shared() From 7cbc700b88b02caba33e8cda6ab23944a4707163 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 20 Jul 2021 22:00:57 +0900 Subject: [PATCH 26/30] revert change in test_tir_ir_builder.py --- tests/python/unittest/test_tir_ir_builder.py | 22 ++++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 6874678a3892..0329134bb3fa 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -538,8 +538,8 @@ def test_device_ir(A, B): s = te.create_schedule(B.op) def check_target(target): - # if not tvm.testing.device_enabled(target): - # return + if not tvm.testing.device_enabled(target): + return freduce = tvm.build(s, [A, B], target) dev = tvm.device(target, 0) @@ -555,13 +555,13 @@ def check_target(target): if __name__ == "__main__": - # test_prefetch() - # test_if() - # test_for() - # test_cpu() - # test_gpu() - # test_while_vectorize() - # test_while_collatz() - # test_while_mandel() - # test_while_binary_search() + test_prefetch() + test_if() + test_for() + test_cpu() + test_gpu() + test_while_vectorize() + test_while_collatz() + test_while_mandel() + test_while_binary_search() test_dyn_shared() From 1151a525dbcb75b0ef8706ab4f0bed0d45ab7a7c Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 20 Jul 2021 22:28:17 +0900 Subject: [PATCH 27/30] make sure kUseDynamicSharedMemoryTag is the last tag --- src/runtime/thread_storage_scope.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index e8bbd87e1006..1b6d7e7b3c0b 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -206,12 
+206,15 @@ class LaunchParamConfig { for (size_t i = 0; i < launch_param_tags.size(); ++i) { const std::string& tag = launch_param_tags[i]; if (tag == kUseDynamicSharedMemoryTag) { + ICHECK_EQ(i, launch_param_tags.size() - 1) + << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; continue; + } else { + ThreadScope ts = ThreadScope::Create(tag); + arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); + filled[ts.rank * 3 + ts.dim_index] = true; } - ThreadScope ts = ThreadScope::Create(tag); - arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); - filled[ts.rank * 3 + ts.dim_index] = true; } work_dim_ = 1; for (int i = 0; i < 3; ++i) { From 587e5b6cbba14a678e4449c0efc4ad4a879996e4 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 20 Jul 2021 22:30:00 +0900 Subject: [PATCH 28/30] remove continue --- src/runtime/thread_storage_scope.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 1b6d7e7b3c0b..6043b311c095 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -209,7 +209,6 @@ class LaunchParamConfig { ICHECK_EQ(i, launch_param_tags.size() - 1) << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; - continue; } else { ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); From ffc138a9413c6a79cb73eb75754512ec659a3d35 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 20 Jul 2021 22:31:43 +0900 Subject: [PATCH 29/30] update doc string following name change --- src/runtime/thread_storage_scope.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 6043b311c095..4a49fc2f17ca 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -197,7 +197,7 @@ struct ThreadWorkLoad { */ inline size_t 
grid_dim(size_t i) const { return work_size[i]; } }; -/*! \brief Thread axis configuration */ +/*! \brief Launch parameters configuration */ class LaunchParamConfig { public: void Init(size_t base, const std::vector& launch_param_tags) { From b8c05a5418a440e6fe404897a20bdce21c95bd6b Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 21 Jul 2021 15:37:37 +0900 Subject: [PATCH 30/30] more comment update following name change --- src/runtime/cuda/cuda_module.cc | 2 +- src/runtime/metal/metal_module.mm | 2 +- src/runtime/opencl/opencl_module.cc | 2 +- src/runtime/rocm/rocm_module.cc | 2 +- src/runtime/thread_storage_scope.h | 2 +- src/runtime/vulkan/vulkan_wrapped_func.h | 5 ++--- 6 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index e3820d8ca569..7d6879a62aba 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -201,7 +201,7 @@ class CUDAWrappedFunc { // Device function cache per device. // mark as mutable, to enable lazy initialization mutable std::array fcache_; - // thread axis configuration + // launch parameters configuration LaunchParamConfig launch_param_config_; }; diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 0e792ad21ec7..1e81ac1bbb34 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -242,7 +242,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons // Device state cache per device. 
// mark as mutable, to enable lazy initialization mutable std::array, kMetalMaxNumDevice> scache_; - // thread axis configuration + // launch parameters configuration LaunchParamConfig launch_param_config_; }; diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index a7efca5cae57..f6c7f6232819 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -96,7 +96,7 @@ class OpenCLWrappedFunc { std::string func_name_; // convert code for void argument std::vector arg_size_; - // thread axis config + // launch parameters config LaunchParamConfig launch_param_config_; }; diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index e02a2c683fdb..487ad23e16b9 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -184,7 +184,7 @@ class ROCMWrappedFunc { // Device function cache per device. // mark as mutable, to enable lazy initialization mutable std::array fcache_; - // thread axis configuration + // launch parameters configuration LaunchParamConfig launch_param_config_; }; diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 4a49fc2f17ca..ac8260ffbe39 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -19,7 +19,7 @@ /*! * \file thread_storage_scope.h - * \brief Extract thread axis configuration from TVMArgs. + * \brief Extract launch parameters configuration from TVMArgs. */ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index d237d2ca8736..cd4774bf0f5a 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -73,11 +73,10 @@ class VulkanWrappedFunc { size_t num_buffer_args_; // number of packed arguments. 
size_t num_pack_args_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; // Device state cache per device. // mark as mutable, to enable lazy initialization - // thread axis configuration - LaunchParamConfig launch_param_config_; - mutable std::array, kVulkanMaxNumDevice> scache_; };