From 8ab4a30812325e3be2b32c47b5d3a369f9e6e7e0 Mon Sep 17 00:00:00 2001 From: Wei Pan Date: Tue, 28 Apr 2020 15:27:26 -0700 Subject: [PATCH] [Optimization] Warp level reduction support for CUDA - Added the warp level reduction support - Upgraded shfl intrinsics to the sync version. - This is the building block for scheduling softmax like operations. Signed-off-by: Wei Pan --- include/tvm/tir/expr.h | 27 +- src/target/source/codegen_cuda.cc | 19 +- src/target/source/codegen_cuda.h | 2 + src/target/source/intrin_rule_cuda.cc | 37 ++- src/target/source/intrin_rule_opencl.cc | 10 +- src/target/source/literal/cuda_half_t.h | 14 + src/tir/transforms/lower_thread_allreduce.cc | 310 ++++++++++++++++--- src/tir/transforms/lower_warp_memory.cc | 8 +- tests/python/integration/test_reduce.py | 102 ++++++ 9 files changed, 462 insertions(+), 67 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index bf0d4f985a92..afa9414fd97d 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1234,22 +1234,43 @@ constexpr const char *tvm_call_trace_packed_lowered = * } */ constexpr const char* tvm_storage_sync = "tvm_storage_sync"; + /*! * \brief See pseudo code * - * Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) { - * return (value passed in by warp indicated by warp_id); + * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id); + * } + * + * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id - offset); + * } + * + * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id + offset); + * } + * + * unsigned tvm_warp_activemask() { + * return (32-bit mask of currently active threads in the calling warp); * } * * Parameter warp_id indicates the source thread ID in a warp. * + * Parameter offset indicates the relative distance to this_warp_id. + * * Parameter width indicates the number of threads involved in one - * shuffle. See CUDA document for __shfl. + * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync, + * __shfl_down_sync and __activemask. * * Parameter warp_size is the size of a warp, which helps a backend * to determine whether the width parameter is legal. + * */ constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle"; +constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up"; +constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down"; +constexpr const char* tvm_warp_activemask = "tvm_warp_activemask"; + /*! * \brief Initialize the global barrier. * Call this at the beginning of a kernel that needs a global barrier.
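[Editor's note] For readers who have not used the *_sync shuffle intrinsics, the following stand-alone CUDA sketch (illustrative only, not part of this patch; the kernel and helper names are made up) shows the pattern that tvm_warp_activemask, tvm_warp_shuffle_down and tvm_warp_shuffle are expected to lower to for a softmax-like row reduction on a 32-lane warp:

    // Illustrative sketch only (assumes warp size 32 and CUDA 9+):
    //   tvm_warp_activemask()                      -> __activemask()
    //   tvm_warp_shuffle_down(mask, v, d, 32, 32)  -> __shfl_down_sync(mask, v, d, 32)
    //   tvm_warp_shuffle(mask, v, 0, 32, 32)       -> __shfl_sync(mask, v, 0, 32)
    #include <cfloat>

    __device__ float warp_reduce_max(float val) {
      unsigned mask = __activemask();                   // lanes currently active in this warp
      for (int offset = 16; offset > 0; offset /= 2) {  // offset halves: 16, 8, 4, 2, 1
        float other = __shfl_down_sync(mask, val, offset, 32);
        val = val > other ? val : other;                // combiner: max
      }
      return __shfl_sync(mask, val, 0, 32);             // broadcast lane 0's result to all lanes
    }

    __global__ void row_max(const float* in, float* out, int rows, int cols) {
      int row = blockIdx.x;                             // one 32-thread warp per row
      if (row >= rows) return;
      float v = -FLT_MAX;
      for (int j = threadIdx.x; j < cols; j += 32) {
        float x = in[row * cols + j];
        v = v > x ? v : x;
      }
      v = warp_reduce_max(v);
      if (threadIdx.x == 0) out[row] = v;               // lane 0 writes the row maximum
    }

Launched as row_max<<<rows, 32>>>(in, out, rows, cols), every lane ends up holding the row maximum after the final broadcast, which is why the lowering below can store the result without a predicated write.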
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index a911e6bf13d6..591e4d0b421c 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -64,6 +64,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << _cuda_half_util; } + if (enable_warp_shuffle_) { + decl_stream << _cuda_warp_intrinsic_util; + } + if (enable_int8_) { decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n"; decl_stream << "#include <sm_61_intrinsics.h>\n"; @@ -269,6 +273,11 @@ void CodeGenCUDA::PrintVecBinaryOp( void CodeGenCUDA::PrintVecElemLoad( const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) + if (t.is_scalar()) { + os << vec; + return; + } + static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4)); if ((t.is_int()) && t.bits() == 8) { @@ -395,7 +404,15 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { os << sret; } -void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { +void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { + // This is only for backward compatibility with __shfl_{up/down}. + // A macro will be used to replace *_sync calls with legacy ones. + if (op->is_intrinsic("__shfl_sync") || + op->is_intrinsic("__shfl_up_sync") || + op->is_intrinsic("__shfl_down_sync")) { + enable_warp_shuffle_ = true; + } + if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 6U); diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index d1db7047b1b6..ed17638d7dc7 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC { bool enable_fp16_{false}; // whether enable int8 bool enable_int8_{false}; + // whether enable warp shuffle intrinsics + bool enable_warp_shuffle_{false}; // whether need math_constants.h bool need_math_constants_h_{false}; // whether need mma.h diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index f40dd5e86bad..47425c3414ac 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -81,14 +81,34 @@ struct CUDAPopcount { } }; + +struct CUDAWarpIntrinsic { + const char* operator()(DataType t, const std::string& name) const { + if (name == intrinsic::tvm_warp_shuffle) { + return "__shfl_sync"; + } + if (name == intrinsic::tvm_warp_shuffle_up) { + return "__shfl_up_sync"; + } + if (name == intrinsic::tvm_warp_shuffle_down) { + return "__shfl_down_sync"; + } + if (name == intrinsic::tvm_warp_activemask) { + return "__activemask"; + } + return ""; + } +}; + +template <typename T> static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as<CallNode>(); CHECK(call != nullptr); - CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size + CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}}; - *rv = CallNode::make( - call->dtype, "__shfl", cuda_args, CallNode::PureExtern); + const char* name = T()(call->dtype, call->name); + *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor") .set_body(DispatchExtern<CUDAMath>); @@ -158,7 +178,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount") .set_body(DispatchExtern<CUDAPopcount>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle") -.set_body(DispatchCUDAShuffle);
+.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up") +.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") +.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") +.set_body(DispatchExtern<CUDAWarpIntrinsic>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod") .set_body(DispatchExtern<CUDAMath>); diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 7374e6d40032..d7f63a671316 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -94,13 +94,13 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as<CallNode>(); CHECK(call != nullptr); - CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size + CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size arith::Analyzer analyzer; - CHECK(analyzer.CanProve(call->args[2] == call->args[3])) + CHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle does not support width != warp_size"; - Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}}; - *rv = CallNode::make( - call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern); + Array<PrimExpr> opencl_args{{call->args[1], call->args[2]}}; + *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle", + opencl_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle") diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index 858ac8572a08..baf4ba733dce 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -295,4 +295,18 @@ __pack_half2(const half x, const half y) { } )"; +static constexpr const char* _cuda_warp_intrinsic_util = R"( +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700) +#define __shfl_sync(mask, var, lane, width) \ + __shfl((var), (lane), (width)) + +#define __shfl_down_sync(mask, var, offset, width) \ + __shfl_down((var), (offset), (width)) + +#define __shfl_up_sync(mask, var, offset, width) \ + __shfl_up((var), (offset), (width)) +#endif + +)"; + #endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9cb817d04b6d..11e420b39873 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -39,8 +39,8 @@ namespace tir { class ThreadAllreduceBuilder final : public StmtExprMutator { public: - explicit ThreadAllreduceBuilder(int warp_size) - : warp_size_(warp_size) {} + explicit ThreadAllreduceBuilder(const TargetNode* target) + : target_(target), warp_size_(target->thread_warp_size) {} Stmt VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == attr::thread_extent) { @@ -84,15 +84,22 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as<AllocateNode>(); - // use volatile access to shared buffer.
- stmt = AttrStmtNode::make( - repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = AllocateNode::make( - repl->buffer_var, repl->dtype, - repl->extents, repl->condition, stmt); - stmt = AttrStmtNode::make( - repl->buffer_var, attr::storage_scope, - StringImmNode::make("shared"), stmt); + if (warp_allocs_.count(repl)) { + stmt = AllocateNode::make(repl->buffer_var, repl->dtype, + repl->extents, repl->condition, op->body); + stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, + StringImmNode::make("local"), stmt); + } else { + // use volatile access to shared buffer. + stmt = AttrStmtNode::make( + repl->buffer_var, attr::volatile_scope, 1, op->body); + stmt = AllocateNode::make( + repl->buffer_var, repl->dtype, + repl->extents, repl->condition, stmt); + stmt = AttrStmtNode::make( + repl->buffer_var, attr::storage_scope, + StringImmNode::make("shared"), stmt); + } return stmt; } else { return stmt; @@ -119,6 +126,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return scope.dim_index < other.scope.dim_index; } }; + // make allreduce. Stmt MakeAllreduce(const CallNode* call) { CHECK(!reduce_combiner_.empty()); @@ -131,7 +139,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Array<PrimExpr> inits = combiner->identity_element; std::vector<PrimExpr> values(size); std::vector<DataType> types(size); - PrimExpr cond = call->args[size+1]; + PrimExpr cond = call->args[size+1]; for (size_t idx = 0; idx < size; ++idx) { values[idx] = call->args[1+idx]; if (!is_one(cond)) { @@ -181,52 +189,196 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::sort(vpar.begin(), vpar.end()); // the size of each index. int reduce_extent, group_extent; - int threadx_extent = 1; PrimExpr reduce_index = FlattenThread(vred, &reduce_extent); PrimExpr group_index = FlattenThread(vpar, &group_extent); - if (reduce_extent == 1) { - // special case, no reduction is needed. - std::vector<Stmt> stores(size); + std::vector<Stmt> seq; + std::vector<Var> shared_bufs(size); + std::vector<Stmt> local_vars; + // + // This is an optimization. For small reduction sizes, it may be beneficial + // for a single warp to perform the entire reduction. No trips to shared + // memory and no cross warp synchronizations are required. + // The following code emits the reduction as follows: + // + // Allocate reduction vars v[i], i = 0..size-1 + // + // for offset from 16 to 1 by 2 + // + // a <- load(v[i]) + // b <- shuffle_down(load(v[i]), offset) + // v[i] <- reduction(a, b) + // + // broadcast results from lane 0 to all other lanes and store + // the final reduction result to the proper location. + // + if (is_warp_reduction(types)) { + // TODO(tvm-team) sub-warp reduction support. + CHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction"; + // + // This is the index to the reduction variable, one reduction + // variable per warp. Local scope seems easier to reason about without + // relying on a pattern match pass to fix it later. + PrimExpr index(0); + + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); + PrimExpr pred = const_true(types[idx].lanes()); + seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred)); + + // Uses a local variable to store the shuffled data. + // Later on, this allocation will be properly attached to this statement.
+ Var var("t" + std::to_string(idx), types[idx]); + Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred, + EvaluateNode::make(0)); + local_vars.push_back(s); + } + + // The mask for this reducer, as this reducer may sit inside + // a divergent control flow. Here it uses a variable to cache the current + // active channels. + // + Var mask_var("mask", DataType::UInt(32)); + { + PrimExpr pred = const_true(1); + PrimExpr mask = CallNode::make(DataType::UInt(32), + intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); + seq.emplace_back(StoreNode::make(mask_var, mask, index, pred)); + // Push allocation with an empty body. Later this will be fixed + // when the entire body is ready. + auto stmt = AllocateNode::make(mask_var, mask_var->dtype, + {PrimExpr(1)}, pred, EvaluateNode::make(0)); + local_vars.push_back(stmt); + } + + // Emit reductions within a warp. + for (int offset = 16; offset > 0; offset /= 2) { + // Load reduction values, no synchronization needed. + Array a, b; + for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; + PrimExpr pred = const_true(types[i].lanes()); + PrimExpr val = LoadNode::make(types[i], var, index, pred); + a.push_back(val); + + // __shfl_*sync calls shall not appear in if_then_else expressions + // as this is causing extra divergency. E.g. + // + // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0); + // + // behaves differently from + // + // int t = __shfl_sync(mask, v1, 0); + // v1 = (v2 < v3) ? v3 : t; + // + // The former may cause dead lock as there is a divergent + // branch with a warp sync call inside. + // + const char* shfl_func = intrinsic::tvm_warp_shuffle_down; + PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset); + const AllocateNode* repl = local_vars[i].as(); + Stmt s = StoreNode::make(repl->buffer_var, other, index, pred); + seq.push_back(s); + + PrimExpr load = LoadNode::make(types[i], repl->buffer_var, index, pred); + b.push_back(load); + } + + // Do reductions. + Array ret = (*combiner)(a, b); + + // Store the reduction result to itself. + std::vector stores(size); + for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; + PrimExpr pred = const_true(types[i].lanes()); + stores[i] = StoreNode::make(var, ret[i], index, pred); + } + seq.push_back(SeqStmt::Flatten(stores)); + } + + // Broadcast the reduction result from lane 0 to all other lanes. + // This avoids to emit predicated stores, as all threads are + // uniformmly writting the same result. + // for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); - Var buffer_var = Downcast(call->args[2+size+i]); - stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); + const char* shfl_func = intrinsic::tvm_warp_shuffle; + PrimExpr val = LoadNode::make(types[i], var, index, pred); + PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0); + seq.push_back(StoreNode::make(var, splat, index, pred)); + } + + // Update existing allocations. + for (size_t i = 0; i < size; ++i) { + CHECK(!load_remap_.count(buffers[i])); + PrimExpr pred = const_true(types[i].lanes()); + Var var = shared_bufs[i]; + load_remap_[buffers[i]] = LoadNode::make(types[i], var, index, pred); + Array extents{PrimExpr(1)}; + auto node = AllocateNode::make(var, types[i], extents, pred, + EvaluateNode::make(0)); + alloc_remap_[buffers[i]] = node; + warp_allocs_.insert(node.get()); + } + } else { + int threadx_extent = 1; + if (reduce_extent == 1) { + // special case, no reduction is needed. 
+ std::vector<Stmt> stores(size); + for (size_t i = 0; i < size; ++i) { + PrimExpr pred = const_true(types[i].lanes()); + Var buffer_var = Downcast<Var>(call->args[2+size+i]); + stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); + } + return SeqStmt::Flatten(stores); + } + // Whether the threadIdx.x is involved in reduction. + if (vred[0].scope.dim_index == 0) { + threadx_extent = vred[0].extent; + } + // This sync is necessary because there might be incomplete read of + // previous iteration on the same buffer. + seq.emplace_back(SyncThread("shared")); + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); + PrimExpr pred = const_true(types[idx].lanes()); + seq.emplace_back(StoreNode::make( + shared_bufs[idx], values[idx], + BufIndex(reduce_index, group_index, reduce_extent), pred)); + } + seq.emplace_back(SyncThread("shared")); + seq.emplace_back(MakeBufAllreduce( + combiner, types, shared_bufs, + reduce_index, group_index, reduce_extent, threadx_extent)); + for (size_t idx = 0; idx < size; ++idx) { + CHECK(!load_remap_.count(buffers[idx])); + PrimExpr pred = const_true(types[idx].lanes()); + load_remap_[buffers[idx]] = LoadNode::make( + types[idx], shared_bufs[idx], + BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); + alloc_remap_[buffers[idx]] = AllocateNode::make( + shared_bufs[idx], types[idx], + {PrimExpr(group_extent), PrimExpr(reduce_extent)}, + pred, EvaluateNode::make(0)); } - return SeqStmt::Flatten(stores); - } - // Whether the threadIdx.x is involved in reduction. - if (vred[0].scope.dim_index == 0) { - threadx_extent = vred[0].extent; - } - std::vector<Stmt> seq; - std::vector<Var> shared_bufs(size); - // This sync is necessary because there might be incomplete read of - // previous iteration on the same buffer. - seq.emplace_back(SyncThread("shared")); - for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); - PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(StoreNode::make( - shared_bufs[idx], values[idx], - BufIndex(reduce_index, group_index, reduce_extent), pred)); - } - seq.emplace_back(SyncThread("shared")); - seq.emplace_back(MakeBufAllreduce( - combiner, types, shared_bufs, - reduce_index, group_index, reduce_extent, threadx_extent)); - for (size_t idx = 0; idx < size; ++idx) { - CHECK(!load_remap_.count(buffers[idx])); - PrimExpr pred = const_true(types[idx].lanes()); - load_remap_[buffers[idx]] = LoadNode::make( - types[idx], shared_bufs[idx], - BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); - alloc_remap_[buffers[idx]] = AllocateNode::make( - shared_bufs[idx], types[idx], - {PrimExpr(group_extent), PrimExpr(reduce_extent)}, - pred, EvaluateNode::make(0)); + + // Fix up all local allocations now that all statements are built. + Stmt body = SeqStmt::Flatten(seq); + for (auto var : local_vars) { + const AllocateNode* repl = var.as<AllocateNode>(); + if (repl) { + body = AllocateNode::make(repl->buffer_var, repl->dtype, + repl->extents, repl->condition, body); + body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, + StringImmNode::make("local"), body); + } } - return SeqStmt::Flatten(seq); + + return body; } + // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode *combiner, const std::vector& types, @@ -330,6 +482,59 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { {StringImmNode::make(sync)}, CallNode::Intrinsic)); } + + // Emit warp shuffle intrinsic calls.
+ PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, + int delta_or_lane) { + PrimExpr pred = const_true(1); + PrimExpr index(0); + PrimExpr mask = LoadNode::make(DataType::UInt(32), mask_var, index, pred); + PrimExpr width = IntImm(DataType::Int(32), warp_size_); + Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), + width, width}; + return CallNode::make(val.dtype(), name, args, CallNode::Intrinsic); + } + + // Check if this is a reduction on threadIdx.x and its extent matches + // the warp size. + // + // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads. + bool is_warp_reduction(const std::vector& types) const { + // Only cuda target supports warp reductions. + if (target_->target_name != "cuda") return false; + + // Supported types: + // {u}int, {u}long, {u}long long, float, double, half/half2 + if (std::any_of(types.begin(), types.end(), [](DataType ty) { + if (ty.is_float16()) return ty.lanes() > 2; + if (ty.is_vector()) return true; + return ty.bytes() < 4 || ty.bytes() > 8; + })) { + return false; + } + if (thread_extents_.empty()) { + return false; + } + + const AttrStmtNode* op = thread_extents_.back(); + DCHECK_EQ(op->attr_key, attr::thread_extent); + + IterVar iv = Downcast<IterVar>(op->node); + ThreadEntry e; + e.scope = runtime::ThreadScope::make(iv->thread_tag); + e.extent = 0; + if (auto ptr = op->value.as<IntImmNode>()) { + e.extent = static_cast<int>(ptr->value); + } + + return e.extent == warp_size_ && + e.scope.dim_index == 0 && + e.scope.rank == 1; + } + + // The target. + const TargetNode* target_ = nullptr; + // The warp size of the device. int warp_size_{1}; @@ -340,6 +545,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::unordered_map<const VarNode*, PrimExpr> load_remap_; // Allocate remap std::unordered_map<const VarNode*, Stmt> alloc_remap_; + // Allocate from warp reductions + std::unordered_set<const Object*> warp_allocs_; // Internal analyzer arith::Analyzer analyzer_; }; @@ -352,7 +559,8 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr<Target>(tvm::attr::kTarget); CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; - n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body); + const TargetNode* target_node = target.as<TargetNode>(); + n->body = ThreadAllreduceBuilder(target_node)(n->body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 516b96cd9c15..0abbe765edc0 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -265,10 +265,12 @@ class WarpAccessRewriter : protected StmtExprMutator { << op->index << " local_index=" << local_index; PrimExpr load_value = LoadNode::make( op->dtype, op->buffer_var, local_index, op->predicate); + PrimExpr mask = CallNode::make(DataType::UInt(32), + intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); return CallNode::make(load_value.dtype(), - intrinsic::tvm_warp_shuffle, - {load_value, group, width_, warp_size_}, - CallNode::Intrinsic); + intrinsic::tvm_warp_shuffle, + {mask, load_value, group, width_, warp_size_}, + CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); } diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index 82ade4478bea..7ac3496c994f 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -338,6 +338,106 @@ def check_target(device): check_target("cuda") check_target("vulkan") +def
test_warp_reduction1(): + nthx = 32 + nthy = 4 + block_x = te.thread_axis("blockIdx.x") + thread_x = te.thread_axis((0, nthx), "threadIdx.x") + thread_y = te.thread_axis((0, nthy), "threadIdx.y") + + def check_target(device, m, n): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("skip because %s is not enabled.." % device) + return + + # compute + A = te.placeholder((m, n), name='A') + k = te.reduce_axis((0, n)) + B = te.compute((m,), lambda i: te.max(A[i][k], axis=k), name='B') + s = te.create_schedule(B.op) + + # schedule + k = s[B].op.reduce_axis[0] + ko, _ = s[B].split(k, nparts=nthx) + s[B].bind(ko, thread_x) + xo, xi = s[B].split(s[B].op.axis[0], factor=nthy) + s[B].bind(xi, thread_y) + s[B].bind(xo, block_x) + + print(tvm.lower(s, [A, B], simple_mode=True)) + + # validation + func = tvm.build(s, [A, B], "cuda", name="warp_reduction") + a_np = np.random.uniform(size=(m,n)).astype(A.dtype) + b_np = np.zeros((m,), dtype=A.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + b_np = np.max(a_np, axis=1) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) + + check_target("cuda", m=32, n=256) + check_target("cuda", m=10, n=20) + # This is a bug in normal reduction. + # check_target("cuda", m=10, n=37) + +def test_warp_reduction2(): + def fcombine(x, y): + return x[0] + y[0], x[1] * y[1] + + def fidentity(t0, t1): + return tvm.tir.const(0, t0), tvm.tir.const(1, t1) + + add_mul_reducer = te.comm_reducer(fcombine, fidentity, name='add_mul_reducer') + + # compute + m = 16 + n = 256 + A0 = te.placeholder((m, n), name='A0', dtype='float32') + A1 = te.placeholder((m, n), name='Al', dtype='float32') + k = te.reduce_axis((0, n), 'k') + T0, T1 = te.compute((m, ), lambda i: \ + add_mul_reducer((A0[i, k], A1[i, k]), axis=k), name='T') + + nthdx, nthdy = 32, 2 + block_x = te.thread_axis("blockIdx.x") + thread_x = te.thread_axis((0, nthdx), "threadIdx.x") + thread_y = te.thread_axis((0, nthdy), "threadIdx.y") + + def check_target(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("skip because %s is not enabled.." % device) + return + + # schedule + s = te.create_schedule(T0.op) + ko, _ = s[T0].split(k, nparts=nthdx) + xo, xi = s[T0].split(s[T0].op.axis[0], factor=nthdy) + s[T0].bind(ko, thread_x) + s[T0].bind(xi, thread_y) + s[T0].bind(xo, block_x) + + # validation + ctx = tvm.context(device, 0) + a0_np = np.random.uniform(size=(m,n)).astype(A0.dtype) + a1_np = np.random.uniform(size=(m,n)).astype(A1.dtype) + t0_np = np.zeros((m,), dtype=A0.dtype) + t1_np = np.zeros((m,), dtype=A1.dtype) + a0 = tvm.nd.array(a0_np, ctx) + a1 = tvm.nd.array(a1_np, ctx) + t0 = tvm.nd.array(t0_np, ctx) + t1 = tvm.nd.array(t1_np, ctx) + func = tvm.build(s, [A0, A1, T0, T1], device, name="reduction") + func(a0, a1, t0, t1) + t0_np = np.sum(a0_np, axis=1) + t1_np = np.product(a1_np, axis=1) + tvm.testing.assert_allclose(t0.asnumpy(), t0_np, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(t1.asnumpy(), t1_np, rtol=1e-3, atol=1e-3) + + check_target("cuda") + if __name__ == "__main__": test_rfactor_elemwise_threads() test_rfactor_threads() @@ -346,3 +446,5 @@ def check_target(device): test_reduce_prims() test_argmax() test_rfactor_argmax() + test_warp_reduction1() + test_warp_reduction2()
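[Editor's note] One subtle point in the lower_thread_allreduce.cc comment above, that __shfl_*sync calls must not appear inside if_then_else expressions, is easier to see in plain CUDA. A hedged sketch (illustrative only, not code generated by this patch):

    // Risky: the shuffle sits inside a divergent branch. Lanes that take the
    // (v2 < v3) path never execute __shfl_sync, so on pre-Volta hardware the
    // lanes that do execute it can wait forever for their peers.
    __device__ float risky(unsigned mask, float v1, float v2, float v3) {
      return (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
    }

    // Safe: every lane executes the shuffle; the branch only selects the result.
    // This mirrors the form the lowering emits (shuffle first, then combine).
    __device__ float safe(unsigned mask, float v1, float v2, float v3) {
      float t = __shfl_sync(mask, v1, 0);
      return (v2 < v3) ? v3 : t;
    }

This is why the lowering stores the shuffled value into a separate local buffer (the t0, t1, ... allocations) before applying the combiner, rather than nesting the shuffle inside the combiner expression.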