From 57d0ee208da9cf6e7e0fe40a82b65fb035e4b614 Mon Sep 17 00:00:00 2001
From: Wei Pan
Date: Thu, 7 May 2020 08:57:35 -0700
Subject: [PATCH] [TOPI] Improve CUDA softmax scheduling

- Do not use multiple kernels
- Schedule with warp reductions
- Fix a bug in the lower warp memory pass
- Fix warp shuffle intrinsics for the nvptx backend
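For reference, a minimal sketch of how the new schedule path can be
exercised end to end; the shapes and the topi.nn.softmax /
topi.cuda.schedule_softmax entry points below are illustrative
assumptions, not part of this patch:

    import numpy as np
    import tvm
    import topi
    from tvm import te

    # Illustrative 2-D workload; names and shapes are not from this patch.
    m, n = 64, 1024
    A = te.placeholder((m, n), name="A")
    B = topi.nn.softmax(A)

    # schedule_softmax() now reads the current target, so enter a cuda scope.
    with tvm.target.create("cuda"):
        s = topi.cuda.schedule_softmax(B)

    func = tvm.build(s, [A, B], "cuda")
    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros((m, n), dtype=B.dtype), ctx)
    func(a, b)

With this schedule a single kernel is launched and each row is reduced
by one warp of 32 threads using shuffle instructions.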
Signed-off-by: Wei Pan
---
 src/target/llvm/codegen_llvm.cc             | 52 +++++++++++++++++++
 src/target/llvm/llvm_common.h               |  1 +
 src/tir/transforms/lower_warp_memory.cc     | 10 ++--
 .../test_tir_transform_lower_warp_memory.py | 36 +++++++++++++
 topi/python/topi/cuda/softmax.py            | 51 +++++++++++++++++-
 5 files changed, 146 insertions(+), 4 deletions(-)

diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index b43e9889a4ae..5c2c41aeece1 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -736,7 +736,40 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
 #endif  // TVM_LLVM_VERSION
 }
 
+// Check if this is a warp shuffle intrinsic call and match its
+// corresponding nvvm intrinsic. Return true if the match is successful.
+static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
+  // Only 32-bit data types are supported.
+  if (op->dtype.is_vector() || op->dtype.bits() != 32) {
+    return false;
+  }
+
+  // Intrinsic lookup table.
+  // It is difficult to emit a _sync version that works on Pascal.
+  // We ignore the mask and only emit the non-sync version for nvptx.
+  llvm::Intrinsic::ID ids[] = {
+      llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
+      llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
+      llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};
+
+  int offset = 0;
+  if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
+    offset = 0;
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
+    offset = 2;
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
+    offset = 4;
+  } else {
+    return false;
+  }
+
+  *id = ids[offset + op->dtype.is_float()];
+  return true;
+}
+
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
+  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
+
   if (op->is_intrinsic("llvm_intrin")) {
     CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
@@ -781,6 +814,25 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
     }
   } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
     return CreateStorageSync(op);
+  } else if (GetWarpShuffleIntrinsic(op, &id)) {
+    std::vector<llvm::Value*> arg_value;
+    std::vector<llvm::Type*> arg_type;
+    // Ignore the first mask operand and remove the last
+    // redundant warp_size argument.
+    size_t n_args = op->args.size() - 1;
+    for (size_t i = 1; i < n_args; ++i) {
+      arg_value.push_back(MakeValue(op->args[i]));
+      arg_type.push_back(arg_value.back()->getType());
+    }
+    llvm::Type* return_type = arg_type[0];
+    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
+    return builder_->CreateCall(func, arg_value);
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
+    // Only nvptx target may keep this intrinsic at this point.
+    // PTX assembly: asm "activemask.b32 r1;"
+    auto fty = llvm::FunctionType::get(t_int32_, false);
+    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
+    return builder_->CreateCall(val);
   } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
     const LoadNode* l = op->args[0].as<LoadNode>();
     CHECK(op->args.size() == 1 && l);
diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h
index 49389fe82ac0..529ee7485fde 100644
--- a/src/target/llvm/llvm_common.h
+++ b/src/target/llvm/llvm_common.h
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <llvm/IR/InlineAsm.h>
 #include
 #include
 #include
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 4c8dec01245c..91879b6a4b82 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -213,9 +213,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
     alloc_size *= op->dtype.lanes();
     std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
     warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
-    CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
-        << "Warp memory must be multiple of the extent of threadIdx.x";
-    warp_group_ = alloc_size / (width_ * warp_coeff_);
+
+    // Align the local memory size. The number of elements may not
+    // be a multiple of width_ * warp_coeff_; round it up.
+    int factor = width_ * warp_coeff_;
+    warp_group_ = (alloc_size + (factor - 1)) / factor;
+    alloc_size = warp_group_ * factor;
+
     return AllocateNode::make(op->buffer_var, op->dtype,
                               {make_const(DataType::Int(32), alloc_size / width_)}, op->condition,
                               this->VisitStmt(op->body));
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index c3cf28982cdc..ce9dd5653457 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -218,9 +218,45 @@ def check_cuda(dtype):
     check_cuda("float32")
     check_cuda("float16")
 
+def test_lower_warp_memory_roundup():
+    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    def check(m):
+        A = te.placeholder((m,), name='A')
+        B = te.compute((m,), lambda i: A[i] + 1, name='B')
+
+        with tvm.target.create("cuda"):
+            s = te.create_schedule(B.op)
+            xo, xi = s[B].split(B.op.axis[0], factor=32)
+            tx = te.thread_axis("threadIdx.x")
+            s[B].bind(xo, te.thread_axis("blockIdx.x"))
+            s[B].bind(xi, tx)
+
+            AA = s.cache_read(A, "warp", [B])
+            _, yi = s[AA].split(s[AA].op.axis[0], factor=32)
+            s[AA].bind(yi, tx)
+            s[AA].compute_at(s[B], xo)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B], "cuda")
+            A_np = np.random.uniform(size=(m,)).astype(A.dtype)
+            B_np = np.zeros(shape=(m,)).astype(B.dtype)
+            A_nd = tvm.nd.array(A_np, ctx)
+            B_nd = tvm.nd.array(B_np, ctx)
+            func(A_nd, B_nd)
+            B_np = A_np + 1
+            tvm.testing.assert_allclose(B_nd.asnumpy(), B_np)
+
+    check(m=31)
+    check(m=32)
+    check(m=33)
+
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
     test_lower_warp_memory_correct_indices()
     test_lower_warp_memory_cuda_end_to_end()
     test_lower_warp_memory_cuda_half_a_warp()
     test_lower_warp_memory_cuda_2_buffers()
+    test_lower_warp_memory_roundup()
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
index 62c437ae96ac..6142f4832763 100644
--- a/topi/python/topi/cuda/softmax.py
+++ b/topi/python/topi/cuda/softmax.py
@@ -16,12 +16,12 @@
 # under the License.
 # pylint: disable=invalid-name, unused-variable, trailing-whitespace
 """Schedule for softmax operator"""
+from tvm import target as target_
 from tvm import te
 from tvm.contrib import cudnn
 from .. import generic
 from .injective import schedule_injective_from_existing
 
-
 def schedule_softmax(outs):
     """Schedule for softmax op.
 
@@ -39,6 +39,7 @@ def schedule_softmax(outs):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
     softmax = outs[0]
+    tgt = target_.Target.current(allow_none=False)
 
     op_tag = softmax.op.tag
     if op_tag == 'softmax_output':
@@ -53,6 +54,14 @@
         raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
 Got {0}'.format(op_tag))
 
+    # The nvptx backend only supports 32-bit warp shuffle instructions.
+    #
+    # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
+    def sched_warp_softmax():
+        if tgt.target_name == "nvptx":
+            return softmax.dtype == "float32" or softmax.dtype == "int32"
+        return True
+
     if len(softmax.shape) > 2:
         ops = [max_elem.op, expsum.op, softmax.op]
         if exp is not None:
@@ -60,6 +69,46 @@ def schedule_softmax(outs):
             ops.append(exp.op)
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
+
+    elif sched_warp_softmax():
+        # A warp of 32 threads performs a row reduction.
+        num_thread = tgt.thread_warp_size
+        block_x = te.thread_axis("blockIdx.x")
+        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
+
+        # (4) softmax
+        xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
+        _, xii = s[softmax].split(xi, factor=4)
+        s[softmax].vectorize(xii)
+        s[softmax].bind(xo, thread_x)
+        s[softmax].bind(softmax.op.axis[0], block_x)
+
+        # (3) expsum
+        k = expsum.op.reduce_axis[0]
+        ko, _ = s[expsum].split(k, nparts=num_thread)
+        s[expsum].bind(ko, thread_x)
+        s[expsum].compute_at(s[softmax], xo)
+
+        # (2) exp
+        if exp is not None:
+            xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
+            _, xii = s[exp].split(xi, factor=4)
+            s[exp].vectorize(xii)
+            s[exp].bind(xo, thread_x)
+            s[exp].compute_at(s[expsum], expsum.op.axis[0])
+            s[exp].compute_at(s[softmax], softmax.op.axis[0])
+            s[exp].set_scope("warp")
+
+        # (1) max_elem
+        k = max_elem.op.reduce_axis[0]
+        ko, _ = s[max_elem].split(k, nparts=num_thread)
+        s[max_elem].bind(ko, thread_x)
+        if exp is not None:
+            s[max_elem].compute_at(s[exp], xo)
+        else:
+            s[max_elem].bind(ko, thread_x)
+            s[max_elem].bind(max_elem.op.axis[0], block_x)
+
     else:
         num_thread = 64
         block_x = te.thread_axis("blockIdx.x")