[TOPI] Improve CUDA softmax scheduling #5600

Merged (1 commit) on May 25, 2020
52 changes: 52 additions & 0 deletions src/target/llvm/codegen_llvm.cc
@@ -736,7 +736,40 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
#endif // TVM_LLVM_VERSION
}

// Check if this is a warp shuffle intrinsic call and match its
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
// Only 32 bit data type is supported.
if (op->dtype.is_vector() || op->dtype.bits() != 32) {
return false;
}

// Intrinsic lookup table.
// It is difficult to emit a _sync version that works on Pascal.
// We ignore the mask and only emit the non-sync version for nvptx.
llvm::Intrinsic::ID ids[] = {
llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};

int offset = 0;
if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
offset = 0;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
offset = 2;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
offset = 4;
} else {
return false;
}

*id = ids[offset + op->dtype.is_float()];
return true;
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;

if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
@@ -781,6 +814,25 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
}
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (GetWarpShuffleIntrinsic(op, &id)) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
// Ignore the first mask operand and drop the last,
// redundant warp_size argument.
size_t n_args = op->args.size() - 1;
for (size_t i = 1; i < n_args; ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::Type* return_type = arg_type[0];
llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
return builder_->CreateCall(func, arg_value);
} else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
// Only the nvptx target may keep this intrinsic at this point.
// PTX assembly: asm "activemask.b32 r1;"
auto fty = llvm::FunctionType::get(t_int32_, false);
auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
return builder_->CreateCall(val);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
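The lookup table above pairs each TVM warp-shuffle intrinsic with its integer and float nvvm counterparts, and `ids[offset + op->dtype.is_float()]` picks one entry per call. A minimal Python sketch of that selection logic (illustrative only, not TVM code; the string names stand in for the `llvm::Intrinsic` IDs):

```python
# Sketch of the lookup in GetWarpShuffleIntrinsic: offset = shuffle kind, +1 if float.
SHUFFLE_IDS = [
    "nvvm_shfl_idx_i32",  "nvvm_shfl_idx_f32",   # offset 0: tvm_warp_shuffle
    "nvvm_shfl_up_i32",   "nvvm_shfl_up_f32",    # offset 2: tvm_warp_shuffle_up
    "nvvm_shfl_down_i32", "nvvm_shfl_down_f32",  # offset 4: tvm_warp_shuffle_down
]

def pick_shuffle(kind, is_float, bits=32, lanes=1):
    """Mirror the C++ logic: only scalar 32-bit types map to an nvvm shuffle."""
    if lanes != 1 or bits != 32:
        return None  # GetWarpShuffleIntrinsic returns false in this case
    offset = {"idx": 0, "up": 2, "down": 4}[kind]
    return SHUFFLE_IDS[offset + int(is_float)]

assert pick_shuffle("down", is_float=True) == "nvvm_shfl_down_f32"
assert pick_shuffle("idx", is_float=False, bits=16) is None
```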
1 change: 1 addition & 0 deletions src/target/llvm/llvm_common.h
@@ -28,6 +28,7 @@
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/InlineAsm.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/Value.h>
#include <llvm/Support/SourceMgr.h>
10 changes: 7 additions & 3 deletions src/tir/transforms/lower_warp_memory.cc
@@ -213,9 +213,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
alloc_size *= op->dtype.lanes();
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
<< "Warp memory must be multiple of the extent of threadIdx.x";
warp_group_ = alloc_size / (width_ * warp_coeff_);

// Align the local memory size. The number of elements may not
// be a multiple of width_ * warp_coeff_; round it up.
int factor = width_ * warp_coeff_;
warp_group_ = (alloc_size + (factor - 1)) / factor;
alloc_size = warp_group_ * factor;

return AllocateNode::make(op->buffer_var, op->dtype,
{make_const(DataType::Int(32), alloc_size / width_)}, op->condition,
this->VisitStmt(op->body));
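The change above swaps the hard `CHECK_EQ` for a round-up: when the number of elements is not a multiple of `width_ * warp_coeff_`, the allocation is padded to the next multiple instead of being rejected. A small Python sketch of the arithmetic (not TVM code), using the same 31/32/33 sizes the new unit test below exercises:

```python
# Ceiling-division round-up performed in WarpAccessRewriter (see the hunk above).
def pad_warp_alloc(alloc_size, width, warp_coeff):
    factor = width * warp_coeff
    warp_group = (alloc_size + factor - 1) // factor  # number of warp groups, rounded up
    return warp_group * factor                        # padded element count

# With width=32 and warp_coeff=1 (one float32 element per lane):
assert pad_warp_alloc(31, 32, 1) == 32  # previously failed the CHECK_EQ
assert pad_warp_alloc(32, 32, 1) == 32  # exact multiple, unchanged
assert pad_warp_alloc(33, 32, 1) == 64  # rounded up to two warp groups
```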
36 changes: 36 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -218,9 +218,45 @@ def check_cuda(dtype):
check_cuda("float32")
check_cuda("float16")

def test_lower_warp_memory_roundup():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return

def check(m):
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i] + 1, name='B')

with tvm.target.create("cuda"):
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
tx = te.thread_axis("threadIdx.x")
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, tx)

AA = s.cache_read(A, "warp", [B])
_, yi = s[AA].split(s[AA].op.axis[0], factor=32)
s[AA].bind(yi, tx)
s[AA].compute_at(s[B], xo)

ctx = tvm.gpu(0)
func = tvm.build(s, [A, B], "cuda")
A_np = np.random.uniform(size=(m,)).astype(A.dtype)
B_np = np.zeros(shape=(m,)).astype(B.dtype)
A_nd = tvm.nd.array(A_np, ctx)
B_nd = tvm.nd.array(B_np, ctx)
func(A_nd, B_nd)
B_np = A_np + 1
tvm.testing.assert_allclose(B_nd.asnumpy(), B_np)

check(m=31)
check(m=32)
check(m=33)

if __name__ == "__main__":
test_lower_warp_memory_local_scope()
test_lower_warp_memory_correct_indices()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()
test_lower_warp_memory_cuda_2_buffers()
test_lower_warp_memory_roundup()
51 changes: 50 additions & 1 deletion topi/python/topi/cuda/softmax.py
@@ -16,12 +16,12 @@
# under the License.
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
from tvm import target as target_
from tvm import te
from tvm.contrib import cudnn
from .. import generic
from .injective import schedule_injective_from_existing


def schedule_softmax(outs):
"""Schedule for softmax op.

@@ -39,6 +39,7 @@ def schedule_softmax(outs):
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
softmax = outs[0]
tgt = target_.Target.current(allow_none=False)
@tqchen (Member) commented on Jun 3, 2020:
We might want to register the warp-level strategies only when the target is cuda, given that the "gpu" schedule is reused by other GPUs that do not support warp.

Reply (Contributor):
Hm. I'd like to benefit from the GPU schedules with warps...


op_tag = softmax.op.tag
if op_tag == 'softmax_output':
@@ -53,13 +54,61 @@ def schedule_softmax(outs):
raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
Got {0}'.format(op_tag))

# The nvptx backend only supports 32-bit warp shuffle instructions.
#
# TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
def sched_warp_softmax():
if tgt.target_name == "nvptx":
return softmax.dtype == "float32" or softmax.dtype == "int32"
return True

if len(softmax.shape) > 2:
ops = [max_elem.op, expsum.op, softmax.op]
if exp is not None:
ops.append(exp.op)

for op in ops:
s = schedule_injective_from_existing(s, op.output(0))

elif sched_warp_softmax():
# A warp of 32 threads performs a row reduction.
num_thread = tgt.thread_warp_size
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis((0, num_thread), "threadIdx.x")

# (4) softmax
xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
_, xii = s[softmax].split(xi, factor=4)
s[softmax].vectorize(xii)
s[softmax].bind(xo, thread_x)
s[softmax].bind(softmax.op.axis[0], block_x)

# (3) expsum
k = expsum.op.reduce_axis[0]
ko, _ = s[expsum].split(k, nparts=num_thread)
s[expsum].bind(ko, thread_x)
s[expsum].compute_at(s[softmax], xo)

# (2) exp
if exp is not None:
xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
_, xii = s[exp].split(xi, factor=4)
s[exp].vectorize(xii)
s[exp].bind(xo, thread_x)
s[exp].compute_at(s[expsum], expsum.op.axis[0])
s[exp].compute_at(s[softmax], softmax.op.axis[0])
s[exp].set_scope("warp")

# (1) max_elem
k = max_elem.op.reduce_axis[0]
ko, _ = s[max_elem].split(k, nparts=num_thread)
s[max_elem].bind(ko, thread_x)
if exp is not None:
s[max_elem].compute_at(s[exp], xo)
else:
s[max_elem].bind(ko, thread_x)
s[max_elem].bind(max_elem.op.axis[0], block_x)

else:
num_thread = 64
block_x = te.thread_axis("blockIdx.x")
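For context, a hedged usage sketch of the new warp-based schedule (assuming a CUDA-capable GPU and that `topi.cuda.schedule_softmax` is exported from this module, as in the TVM Python API of this period; the shapes and tolerance are arbitrary). With a 2-D float32 input on the cuda target, `sched_warp_softmax()` returns True and the warp branch above is taken:

```python
import numpy as np
import tvm
from tvm import te
import topi

m, n = 64, 256  # n is a multiple of the warp size, so the factor-4 vectorization divides evenly
A = te.placeholder((m, n), name="A", dtype="float32")
B = topi.nn.softmax(A)

with tvm.target.create("cuda"):
    s = topi.cuda.schedule_softmax(B)
    f = tvm.build(s, [A, B], "cuda")

ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(m, n)).astype("float32")
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros((m, n), dtype="float32"), ctx)
f(a, b)

# Reference softmax computed row-wise in numpy.
ref = np.exp(a_np - a_np.max(axis=1, keepdims=True))
ref /= ref.sum(axis=1, keepdims=True)
tvm.testing.assert_allclose(b.asnumpy(), ref, rtol=1e-5)
```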