Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Optimization] Warp level reduction support for CUDA #5498

Merged
merged 1 commit into from
May 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1234,22 +1234,43 @@ constexpr const char *tvm_call_trace_packed_lowered =
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";

/*!
* \brief See pseudo code
*
* Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) {
* return (value passed in by warp indicated by warp_id);
* Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) {
* return (value passed in by warp indicated by this_warp_id);
* }
*
* Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) {
* return (value passed in by warp indicated by this_warp_id - offset);
* }
*
* Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) {
* return (value passed in by warp indicated by this_warp_id + offset);
* }
*
* unsigned tvm_warp_activemask() {
* return (32-bit mask of currently active threads in the calling warp);
* }
*
* Parameter warp_id indicates the source thread ID in a warp.
*
* Parameter offset indicates the relative distance to this_warp_id.
*
* Parameter width indicates the number of threads involved in one
* shuffle. See CUDA document for __shfl.
* shuffle. See CUDA document for __shfl_sync, __shfl_up_sync,
* __shfl_down_sync and __activemask.
*
* Parameter warp_size is the size of a warp, which helps a backend
* to determine wheter the width paramter is legal.
*
*/
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up";
constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down";
constexpr const char* tvm_warp_activemask = "tvm_warp_activemask";

/*!
* \brief Initialize the global barrier.
* Call this at beginning of kernel that need global barrier.
Expand Down
19 changes: 18 additions & 1 deletion src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << _cuda_half_util;
}

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
}

if (enable_int8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
decl_stream << "#include <sm_61_intrinsics.h>\n";
Expand Down Expand Up @@ -269,6 +273,11 @@ void CodeGenCUDA::PrintVecBinaryOp(

void CodeGenCUDA::PrintVecElemLoad(
const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*)
if (t.is_scalar()) {
os << vec;
return;
}

static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if ((t.is_int()) && t.bits() == 8) {
Expand Down Expand Up @@ -395,7 +404,15 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
os << sret;
}

void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// This is only for backward compatibility with __shfl_{up/down}.
// A macro will be used to replace *_sync calls to legacy ones.
if (op->is_intrinsic("__shfl_sync") ||
op->is_intrinsic("__shfl_up_sync") ||
op->is_intrinsic("__shfl_down_sync")) {
enable_warp_shuffle_ = true;
}

if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC {
bool enable_fp16_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics
bool enable_warp_shuffle_{false};
// whether need math_constants.h
bool need_math_constants_h_{false};
// whether need mma.h
Expand Down
37 changes: 33 additions & 4 deletions src/target/source/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,34 @@ struct CUDAPopcount {
}
};


struct CUDAWarpIntrinsic {
const char* operator()(DataType t, const std::string& name) const {
if (name == intrinsic::tvm_warp_shuffle) {
return "__shfl_sync";
}
if (name == intrinsic::tvm_warp_shuffle_up) {
return "__shfl_up_sync";
}
if (name == intrinsic::tvm_warp_shuffle_down) {
return "__shfl_down_sync";
}
if (name == intrinsic::tvm_warp_activemask) {
return "__activemask";
}
return "";
}
};

template <typename T>
static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size
CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
*rv = CallNode::make(
call->dtype, "__shfl", cuda_args, CallNode::PureExtern);
const char* name = T()(call->dtype, call->name);
*rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
Expand Down Expand Up @@ -158,7 +178,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchCUDAShuffle);
.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up")
.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down")
.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask")
.set_body(DispatchExtern<CUDAWarpIntrinsic>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
.set_body(DispatchExtern<CUDAMath>);
Expand Down
10 changes: 5 additions & 5 deletions src/target/source/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size
CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
arith::Analyzer analyzer;
CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
<< "Intel warp shuffle dose not support width != warp_size";
Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}};
*rv = CallNode::make(
call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
Array<PrimExpr> opencl_args{{call->args[1], call->args[2]}};
*rv = CallNode::make(call->dtype, "intel_sub_group_shuffle",
opencl_args, CallNode::PureExtern);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
Expand Down
14 changes: 14 additions & 0 deletions src/target/source/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,18 @@ __pack_half2(const half x, const half y) {
}
)";

static constexpr const char* _cuda_warp_intrinsic_util = R"(
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
__shfl((var), (lane), (width))

#define __shfl_down_sync(mask, var, offset, width) \
__shfl_down((var), (offset), (width))

#define __shfl_up_sync(mask, var, offset, width) \
__shfl_up((var), (offset), (width))
#endif

)";

#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
Loading