[TOPI] Improve CUDA softmax scheduling (apache#5600)
- Do not use multiple kernels

- Schedule with warp reductions

- Fixed a bug in the lower warp memory pass

- Fixed warp shuffle intrinsics for the nvptx backend.

Signed-off-by: Wei Pan <weip@nvidia.com>
wpan11nv authored and trevor-m committed Jun 18, 2020
1 parent 16e23cd commit 5deb83b
Showing 5 changed files with 146 additions and 4 deletions.
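
For orientation, the "warp reductions" in the commit message mean binding the outer part of a reduction axis to threadIdx.x so that a single warp cooperates on one row. A minimal te sketch of that pattern, written against the TVM te API of this era (illustrative only, not code from this commit):

from tvm import te

n = te.var("n")
m = 128
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")

s = te.create_schedule(B.op)
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis((0, 32), "threadIdx.x")
ko, _ = s[B].split(B.op.reduce_axis[0], nparts=32)
s[B].bind(B.op.axis[0], block_x)   # one row per block
s[B].bind(ko, thread_x)            # cross-thread reduction within a warp

When threadIdx.x is bound to a reduce axis like this, TVM lowers the reduction to tvm_thread_allreduce, which the CUDA backend may implement with warp shuffles rather than shared memory.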
52 changes: 52 additions & 0 deletions src/target/llvm/codegen_llvm.cc
@@ -736,7 +736,40 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
#endif // TVM_LLVM_VERSION
}

// Check if this is a warp shuffle intrinsic call and match its
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
// Only 32-bit data types are supported.
if (op->dtype.is_vector() || op->dtype.bits() != 32) {
return false;
}

// Intrinsic lookup table.
// It is difficult to emit a _sync version that works on Pascal.
// We ignore the mask and only emit the non-sync version for nvptx.
llvm::Intrinsic::ID ids[] = {
llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};

int offset = 0;
if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
offset = 0;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
offset = 2;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
offset = 4;
} else {
return false;
}

*id = ids[offset + op->dtype.is_float()];
return true;
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;

if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
@@ -781,6 +814,25 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
}
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (GetWarpShuffleIntrinsic(op, &id)) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
// Ignore the first mask operand and drop the last,
// redundant warp_size argument.
size_t n_args = op->args.size() - 1;
for (size_t i = 1; i < n_args; ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::Type* return_type = arg_type[0];
llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
return builder_->CreateCall(func, arg_value);
} else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
// Only nvptx target may keep this intrinsic at this point.
// PTX assembly: asm "activemask.b32 r1;"
auto fty = llvm::FunctionType::get(t_int32_, false);
auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
return builder_->CreateCall(val);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
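The lookup table in GetWarpShuffleIntrinsic above is indexed as offset + is_float: offsets 0/2/4 select the idx/up/down variants, and the float flag picks the f32 intrinsic over the i32 one. A tiny Python mirror of that selection, purely illustrative (the strings are just intrinsic names, not a real API):

NVVM_SHUFFLES = [
    "nvvm_shfl_idx_i32",  "nvvm_shfl_idx_f32",
    "nvvm_shfl_up_i32",   "nvvm_shfl_up_f32",
    "nvvm_shfl_down_i32", "nvvm_shfl_down_f32",
]

def pick_shuffle(kind, is_float):
    # kind -> offset, mirroring the if/else chain in GetWarpShuffleIntrinsic
    offset = {"shuffle": 0, "shuffle_up": 2, "shuffle_down": 4}[kind]
    return NVVM_SHUFFLES[offset + int(is_float)]

assert pick_shuffle("shuffle_down", is_float=True) == "nvvm_shfl_down_f32"
assert pick_shuffle("shuffle", is_float=False) == "nvvm_shfl_idx_i32"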
1 change: 1 addition & 0 deletions src/target/llvm/llvm_common.h
@@ -28,6 +28,7 @@
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/InlineAsm.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/Value.h>
#include <llvm/Support/SourceMgr.h>
10 changes: 7 additions & 3 deletions src/tir/transforms/lower_warp_memory.cc
@@ -213,9 +213,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
alloc_size *= op->dtype.lanes();
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
<< "Warp memory must be multiple of the extent of threadIdx.x";
warp_group_ = alloc_size / (width_ * warp_coeff_);

// Align the local memory size. The number of elements may not
// be a multiple of width_ * warp_coeff_; round it up.
int factor = width_ * warp_coeff_;
warp_group_ = (alloc_size + (factor - 1)) / factor;
alloc_size = warp_group_ * factor;

return AllocateNode::make(op->buffer_var, op->dtype,
{make_const(DataType::Int(32), alloc_size / width_)}, op->condition,
this->VisitStmt(op->body));
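To make the rounding fix concrete, here is the same arithmetic in standalone Python (illustrative only; padded_alloc is a hypothetical helper, not a TVM function):

def padded_alloc(alloc_size, width, warp_coeff):
    factor = width * warp_coeff
    warp_group = (alloc_size + factor - 1) // factor   # ceiling division
    return warp_group * factor

# The old CHECK_EQ rejected alloc_size = 33 with a 32-wide warp;
# the new code pads it to the next warp-group boundary instead.
assert padded_alloc(33, 32, 1) == 64
assert padded_alloc(32, 32, 1) == 32

This is exactly the case exercised by the new test below with m = 31 and m = 33.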
36 changes: 36 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -218,9 +218,45 @@ def check_cuda(dtype):
check_cuda("float32")
check_cuda("float16")

def test_lower_warp_memory_roundup():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return

def check(m):
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i] + 1, name='B')

with tvm.target.create("cuda"):
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
tx = te.thread_axis("threadIdx.x")
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, tx)

AA = s.cache_read(A, "warp", [B])
_, yi = s[AA].split(s[AA].op.axis[0], factor=32)
s[AA].bind(yi, tx)
s[AA].compute_at(s[B], xo)

ctx = tvm.gpu(0)
func = tvm.build(s, [A, B], "cuda")
A_np = np.random.uniform(size=(m,)).astype(A.dtype)
B_np = np.zeros(shape=(m,)).astype(B.dtype)
A_nd = tvm.nd.array(A_np, ctx)
B_nd = tvm.nd.array(B_np, ctx)
func(A_nd, B_nd)
B_np = A_np + 1
tvm.testing.assert_allclose(B_nd.asnumpy(), B_np)

check(m=31)
check(m=32)
check(m=33)

if __name__ == "__main__":
test_lower_warp_memory_local_scope()
test_lower_warp_memory_correct_indices()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()
test_lower_warp_memory_cuda_2_buffers()
test_lower_warp_memory_roundup()
51 changes: 50 additions & 1 deletion topi/python/topi/cuda/softmax.py
@@ -16,12 +16,12 @@
# under the License.
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
from tvm import target as target_
from tvm import te
from tvm.contrib import cudnn
from .. import generic
from .injective import schedule_injective_from_existing


def schedule_softmax(outs):
"""Schedule for softmax op.
@@ -39,6 +39,7 @@ def schedule_softmax(outs):
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
softmax = outs[0]
tgt = target_.Target.current(allow_none=False)

op_tag = softmax.op.tag
if op_tag == 'softmax_output':
@@ -53,13 +54,61 @@ def schedule_softmax(outs):
raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
Got {0}'.format(op_tag))

# The nvptx backend only supports 32-bit warp shuffle instructions.
#
# TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
def sched_warp_softmax():
if tgt.target_name == "nvptx":
return softmax.dtype == "float32" or softmax.dtype == "int32"
return True

if len(softmax.shape) > 2:
ops = [max_elem.op, expsum.op, softmax.op]
if exp is not None:
ops.append(exp.op)

for op in ops:
s = schedule_injective_from_existing(s, op.output(0))

elif sched_warp_softmax():
# A warp of 32 threads performs a row reduction.
num_thread = tgt.thread_warp_size
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis((0, num_thread), "threadIdx.x")

# (4) softmax
xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
_, xii = s[softmax].split(xi, factor=4)
s[softmax].vectorize(xii)
s[softmax].bind(xo, thread_x)
s[softmax].bind(softmax.op.axis[0], block_x)

# (3) expsum
k = expsum.op.reduce_axis[0]
ko, _ = s[expsum].split(k, nparts=num_thread)
s[expsum].bind(ko, thread_x)
s[expsum].compute_at(s[softmax], xo)

# (2) exp
if exp is not None:
xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
_, xii = s[exp].split(xi, factor=4)
s[exp].vectorize(xii)
s[exp].bind(xo, thread_x)
s[exp].compute_at(s[expsum], expsum.op.axis[0])
s[exp].compute_at(s[softmax], softmax.op.axis[0])
s[exp].set_scope("warp")

# (1) max_elem
k = max_elem.op.reduce_axis[0]
ko, _ = s[max_elem].split(k, nparts=num_thread)
s[max_elem].bind(ko, thread_x)
if exp is not None:
s[max_elem].compute_at(s[exp], xo)
else:
s[max_elem].bind(ko, thread_x)
s[max_elem].bind(max_elem.op.axis[0], block_x)

else:
num_thread = 64
block_x = te.thread_axis("blockIdx.x")
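A rough usage sketch of the schedule above, assuming the standalone topi package layout and entry points of this TVM era (topi.nn.softmax, topi.cuda.softmax.schedule_softmax); illustrative only, not part of the diff:

import tvm
from tvm import te
import topi
from topi.cuda.softmax import schedule_softmax

x = te.placeholder((128, 1024), name="x", dtype="float32")
y = topi.nn.softmax(x)

with tvm.target.create("cuda"):
    s = schedule_softmax([y])
    f = tvm.build(s, [x, y], "cuda")

With the warp path taken, one warp reduces each row, and max_elem, expsum and the final normalization are scheduled into a single kernel rather than separate ones.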
