From 7339910eb9fb92413e3d6e424138e800f5bc8c3a Mon Sep 17 00:00:00 2001 From: Yining Shi Date: Mon, 15 Jul 2024 19:35:16 -0700 Subject: [PATCH 01/23] [tl] Add reg hint in ws pipeline. --- python/tvm/tl/engine.py | 2 +- python/tvm/tl/utils.py | 15 +++++++++++++-- src/tl/op/builtin.cc | 4 ++++ src/tl/op/builtin.h | 8 ++++++++ src/tl/target/codegen.cc | 12 ++++++++---- src/tl/tl_templates/copy_sm90.h | 10 ++++++++++ src/tl/transform/warp_specialized_pipeline.cc | 11 +++++++++-- 7 files changed, 53 insertions(+), 9 deletions(-) diff --git a/python/tvm/tl/engine.py b/python/tvm/tl/engine.py index acde0fb4fbcb..b76fae3c7b84 100644 --- a/python/tvm/tl/engine.py +++ b/python/tvm/tl/engine.py @@ -47,7 +47,7 @@ def tvm_callback_cuda_compile(code, target): format = "cubin" else: arch = [f"-arch=sm_{compute_version}"] - format = "ptx" + format = "cubin" ptx = nvcc.compile_cuda( code, diff --git a/python/tvm/tl/utils.py b/python/tvm/tl/utils.py index f7b67681ea9c..365093c4a047 100644 --- a/python/tvm/tl/utils.py +++ b/python/tvm/tl/utils.py @@ -143,9 +143,20 @@ def assert_consistent(self, repeat=10): for lhs, rhs in zip(lib_outs, ref_outs): assert torch.allclose(lhs, rhs), ["result is not consistent", lhs, rhs] - def run_once(self): + def run_once(self, func=None): + import ctypes + libcuda = ctypes.CDLL("libcuda.so") + ins = self._get_inputs() - return self.__call__(*ins) + if not func: + func = self.__call__ + + libcuda.cuProfilerStart() + res = func(*ins) + libcuda.cuProfilerStop() + return res + + def do_bench(self, func: callable, warmup=25, rep=100): ins = self._get_inputs() diff --git a/src/tl/op/builtin.cc b/src/tl/op/builtin.cc index 7bd7d69dcedf..487c460f9561 100644 --- a/src/tl/op/builtin.cc +++ b/src/tl/op/builtin.cc @@ -92,6 +92,10 @@ TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp) .set_num_inputs(0) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(SetMaxNReg) + .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(PackB16Op).set_num_inputs(2).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); } // namespace tl diff --git a/src/tl/op/builtin.h b/src/tl/op/builtin.h index ae83f3a906a9..49e1c21101ee 100644 --- a/src/tl/op/builtin.h +++ b/src/tl/op/builtin.h @@ -146,6 +146,14 @@ const Op& SyncThreadsPartialOp(); */ const Op& FenceProxyAsyncOp(); +/*! + * \brief Set reg hint for warp-specialized branched + * + * SetMaxNRegInc(num_reg, is_inc) + * + */ +const Op& SetMaxNReg(); + } // namespace tl } // namespace tvm diff --git a/src/tl/target/codegen.cc b/src/tl/target/codegen.cc index 6f67322df279..d2253820e2b1 100644 --- a/src/tl/target/codegen.cc +++ b/src/tl/target/codegen.cc @@ -447,8 +447,7 @@ void CodeGenTL::PrintVecElemStore(const std::string& vec, DataType t, int i, ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.lanes() == 2 || t.lanes() == 3) { - stream << vec << '.' << access[i % t.lanes()] << "=" - << "(" << value << ");\n"; + stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n"; } else { std::string ac = t.lanes() == 4 ? vec : (vec + "." 
+ access[i / 4]); stream << ac << "="; @@ -685,6 +684,12 @@ void CodeGenTL::VisitExpr_(const CallNode* op, std::ostream& os) { print_extern_call_stmt(func_name, 2); } else if (op->op.same_as(tl::FenceProxyAsyncOp())) { print_extern_call_stmt("tl::fence_proxy_async"); + } else if (op->op.same_as(tl::SetMaxNReg())) { + this->PrintIndent(); + int nreg = Downcast(op->args[0])->value; + int is_inc = Downcast(op->args[1])->value; + std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc"; + this->stream << func_name << "<" << std::to_string(nreg) << ">();\n"; } else if (op->op.same_as(tl::PackB16Op())) { os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ")"; @@ -756,8 +761,7 @@ void CodeGenTL::VisitExpr_(const RampNode* op, std::ostream& os) { PrintType(op->dtype, os); os << "("; for (int i = 0; i < op->lanes; i++) { - os << "(" << PrintExpr(op->base) << ")" - << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; if (i != op->lanes - 1) os << ", "; } os << "))"; diff --git a/src/tl/tl_templates/copy_sm90.h b/src/tl/tl_templates/copy_sm90.h index a60c11629574..b724aeeafe03 100644 --- a/src/tl/tl_templates/copy_sm90.h +++ b/src/tl/tl_templates/copy_sm90.h @@ -214,4 +214,14 @@ TL_DEVICE void syncthreads_partial(uint64_t& smem_barrier) { : "r"(smem_int_ptr), "l"(state)); } +template +TL_DEVICE void warpgroup_reg_alloc(){ + asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); +} + +template +TL_DEVICE void warpgroup_reg_dealloc(){ + asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); +} + } // namespace tl \ No newline at end of file diff --git a/src/tl/transform/warp_specialized_pipeline.cc b/src/tl/transform/warp_specialized_pipeline.cc index d93a19472cd7..b016fe8a62d5 100644 --- a/src/tl/transform/warp_specialized_pipeline.cc +++ b/src/tl/transform/warp_specialized_pipeline.cc @@ -786,8 +786,15 @@ class WarpSpecializedPipeline : public StmtExprMutator { PrimExpr consumer_thread_extent = thread_iv_->dom->extent; PrimExpr producer_thread_extent = thread_iv_->dom->extent; - // Only need one thread for bulk-copy only case - if (!marker.HasSimtCopy()) producer_thread_extent = 1; + // Need one warp-group for bulk-copy only case + if (!marker.HasSimtCopy()) producer_thread_extent = 128; + + // TODO: estimate the correct reg usage. 
+ auto inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1})); + auto dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0})); + + producer_code = SeqStmt({dec_reg_stmt, producer_code}); + consumer_code = SeqStmt({inc_reg_stmt, consumer_code}); producer_code = ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var, thread_iv_->var - consumer_thread_extent); From 1179b0375edaecf72046bf90668de3c77c75db9d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 27 Jul 2024 15:34:25 +0000 Subject: [PATCH 02/23] add inject proxy fence pass * analysis stmt proxy and inject fence between different proxies, correct on gemm_rs examples * still fail on some configs (bm=bn=64) for mha for h100 --- python/tvm/tl/engine.py | 48 ++++++- python/tvm/tl/transform.py | 10 ++ python/tvm/tl/utils.py | 1 + src/tl/op/bulk_copy.cc | 10 +- src/tl/target/rt_mod.cc | 6 + src/tl/tl_templates/gemm_sm90.h | 2 + src/tl/transform/inject_fence_proxy.cc | 170 +++++++++++++++++++++++++ src/tl/transform/pipeline_planning.cc | 2 +- tl_scripts/gemm_example.py | 5 +- tl_scripts/gemm_rs_example.py | 71 +++++++++++ 10 files changed, 314 insertions(+), 11 deletions(-) create mode 100644 src/tl/transform/inject_fence_proxy.cc create mode 100644 tl_scripts/gemm_rs_example.py diff --git a/python/tvm/tl/engine.py b/python/tvm/tl/engine.py index acde0fb4fbcb..3e010ca942c3 100644 --- a/python/tvm/tl/engine.py +++ b/python/tvm/tl/engine.py @@ -70,6 +70,11 @@ def extrac_params(func: tir.PrimFunc): tensor_types = [relay.TensorType(buffer.shape, buffer.dtype) for buffer in buffers] return tensor_types +@tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) +def tvm_callback_cuda_postproc(code, _): + code = code.replace("""original code""", +"""modified code""") + return code def lower(func): params = extrac_params(func) @@ -79,14 +84,45 @@ def lower(func): target = tvm.target.Target("cuda", target_host) mod = tir.transform.BindTarget(target)(mod) + # print('-'*100 + '\n' + 'after BindTarget\n' + '-'*100) + # print(mod) + mod = tl.transform.FrontendLegalize()(mod) + + # print('-'*100 + '\n' + 'after FrontendLegalize\n' + '-'*100) + # print(mod) + mod = tir.transform.Simplify()(mod) + + # print('-'*100 + '\n' + 'after Simplify\n' + '-'*100) + # print(mod) + mod = tl.transform.LayoutInference()(mod) + + # print('-'*100 + '\n' + 'after LayoutInference\n' + '-'*100) + # print(mod) + mod = tl.transform.LowerTileOp()(mod) + + # print('-'*100 + '\n' + 'after LowerTileOp\n' + '-'*100) + # print(mod) + mod = tir.transform.Simplify()(mod) + # print('-'*100 + '\n' + 'after Simplify\n' + '-'*100) + # print(mod) + if target.arch == "sm_90": mod = tl.transform.WarpSpecializedPipeline()(mod) + + # print('-'*100 + '\n' + 'after WarpSpecializedPipeline\n' + '-'*100) + # print(mod) + + mod = tl.transform.InjectFenceProxy()(mod) + + # print('-'*100 + '\n' + 'after InjectFenceProxy\n' + '-'*100) + # print(mod) + else: mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) mod = tl.transform.PipelinePlanning()(mod) @@ -128,14 +164,22 @@ def lower(func): host_mod = tir.transform.LowerIntrin()(host_mod) host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod) host_mod = tir.transform.CombineContextCall()(host_mod) + # host_code = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host).get_source() + # print("=" * 100) + # print("host code:") + # print("=" * 100) + # print(host_code) host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target) device_mod = 
tir.transform.Filter(is_device_call)(mod) device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod) device_mod = tir.transform.LowerIntrin()(device_mod) device_mod = tir.transform.Simplify()(device_mod) - # code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target) - # print(code) + # device_code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target) + # print("=" * 100) + # print("device code:") + # print("=" * 100) + # print(device_code) device_mod = tvm._ffi.get_global_func("target.build.tl")(device_mod, target) host_mod.import_module(device_mod) diff --git a/python/tvm/tl/transform.py b/python/tvm/tl/transform.py index f19edf8083ea..f040532cbdac 100644 --- a/python/tvm/tl/transform.py +++ b/python/tvm/tl/transform.py @@ -106,3 +106,13 @@ def WarpSpecializedPipeline(): The result pass """ return _ffi_api.WarpSpecializedPipeline() # type: ignore + +def InjectFenceProxy(): + """InjectFenceProxy + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectFenceProxy() # type: ignore \ No newline at end of file diff --git a/python/tvm/tl/utils.py b/python/tvm/tl/utils.py index 7f9218116fa7..9dd11ad4febb 100644 --- a/python/tvm/tl/utils.py +++ b/python/tvm/tl/utils.py @@ -130,6 +130,7 @@ def assert_allclose(self, reference_program: callable, atol: float = 1e-8, rtol: if isinstance(ref_outs, torch.Tensor): ref_outs = [ref_outs] assert len(lib_outs) == len(ref_outs) + # torch.set_printoptions(edgeitems=torch.inf) for lhs, rhs in zip(lib_outs, ref_outs): assert torch.allclose(lhs, rhs, rtol=rtol, atol=atol), (lhs, rhs) diff --git a/src/tl/op/bulk_copy.cc b/src/tl/op/bulk_copy.cc index 33010727ea91..9a6dc099f6b2 100644 --- a/src/tl/op/bulk_copy.cc +++ b/src/tl/op/bulk_copy.cc @@ -228,11 +228,11 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { } tma_copy = IfThenElse(EQ(T.thread_var, 0), tma_copy); - if (!is_load) { - // TODO: Add this async proxy fence with a seperate pass - auto fence_stmt = Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {})); - tma_copy = SeqStmt({fence_stmt, tma_copy}); - } + // if (!is_load) { + // // TODO: Add this async proxy fence with a seperate pass + // auto fence_stmt = Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {})); + // tma_copy = SeqStmt({fence_stmt, tma_copy}); + // } return tma_copy; } diff --git a/src/tl/target/rt_mod.cc b/src/tl/target/rt_mod.cc index fd02abde7820..9ddae3e4a4b5 100644 --- a/src/tl/target/rt_mod.cc +++ b/src/tl/target/rt_mod.cc @@ -67,6 +67,9 @@ runtime::Module BuildTL(IRModule mod, Target target) { } std::string code = cg.Finish(); + if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) { + code = (*f)(code, target).operator std::string(); + } std::string fmt = "ptx"; std::string ptx; if (const auto* f = Registry::Get("tvm_tl_cuda_compile")) { @@ -93,6 +96,9 @@ String BuildTLDebug(IRModule mod, Target target) { } std::string code = cg.Finish(); + if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) { + code = (*f)(code, target).operator std::string(); + } return String(code); } diff --git a/src/tl/tl_templates/gemm_sm90.h b/src/tl/tl_templates/gemm_sm90.h index e4bc162f9c85..30ee86808100 100644 --- a/src/tl/tl_templates/gemm_sm90.h +++ b/src/tl/tl_templates/gemm_sm90.h @@ -98,6 +98,8 @@ class GemmTensorOp { } static CUTE_DEVICE void body_rs(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) { + // TODO: Move bar.sync out of body_rs + asm volatile("bar.sync %0, %1;" : : 
"r"(1), "r"(num_warp_m * num_warp_n * 32)); const int tid = threadIdx.x; Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), SmemLayoutB{}); auto tiled_mma = diff --git a/src/tl/transform/inject_fence_proxy.cc b/src/tl/transform/inject_fence_proxy.cc new file mode 100644 index 000000000000..51def375bd59 --- /dev/null +++ b/src/tl/transform/inject_fence_proxy.cc @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inject_fence_proxy.cc + * \brief Inject fence between generic and async proxies (sm90+) + */ + +#include +#include +#include +#include +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +enum class Proxy { kGeneric, kAsync, kBoth }; + +class ProxyMarker : public StmtVisitor { + public: + ProxyMarker() = default; + + Proxy GetProxy(const StmtNode* stmt) const { + auto it = map_.find(stmt); + ICHECK(it != map_.end()); + return it->second; + } + + Proxy GetProxy(const Stmt& stmt) const { return GetProxy(stmt.get()); } + + void VisitStmt_(const EvaluateNode* op) final { + Proxy proxy = Proxy::kAsync; + if (auto call = op->value.as()) { + if (call->op.same_as(LDMatrixOp()) || call->op.same_as(STMatrixOp())) { + proxy = Proxy::kGeneric; + } + } + SetProxy(op, proxy); + } + + void VisitStmt_(const BufferStoreNode* op) final { + Proxy proxy = Proxy::kGeneric; + SetProxy(op, proxy); + } + + void VisitStmt_(const SeqStmtNode* op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetProxy(op->seq[0]); + for (auto stmt : op->seq) { + if (role != GetProxy(stmt)) { + role = Proxy::kBoth; + break; + } + } + SetProxy(op, role); + } + + void VisitStmt_(const IfThenElseNode* op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetProxy(op->then_case); + if (op->else_case.defined()) { + auto role_else = GetProxy(op->else_case.value()); + if (role != role_else) role = Proxy::kBoth; + } + SetProxy(op, role); + } + + void VisitStmt_(const BlockRealizeNode* op) final { + StmtVisitor::VisitStmt_(op); + SetProxy(op, GetProxy(op->block)); + } + + template + void HandleBodyStmt(const NodeType* op) { + StmtVisitor::VisitStmt_(op); + SetProxy(op, GetProxy(op->body)); + } + + void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); } + + + + private: + void SetProxy(const StmtNode* stmt, Proxy proxy) { map_[stmt] = proxy; } + std::unordered_map map_; +}; + + +class InjectFenceProxy : public StmtExprMutator { + public: + static PrimFunc Substitute(PrimFunc f) { + auto T = 
InjectFenceProxy(); + f.CopyOnWrite()->body = T(f->body); + return f; + } + + private: + Proxy get_generic_proxy(const Stmt& stmt) { + auto marker = ProxyMarker(); + marker(stmt); + return marker.GetProxy(stmt); + } + + Stmt VisitStmt_(const SeqStmtNode* op) final { + ICHECK(op->seq.size() > 0); + Array new_body; + Proxy cur_proxy, prev_proxy; + auto fence_stmt = Evaluate(Call(DataType::Handle(), FenceProxyAsyncOp(), {})); + prev_proxy = get_generic_proxy(op->seq[0]); + new_body.push_back(VisitStmt(op->seq[0])); + if (op->seq.size() > 1) { + for (int i = 1; i < static_cast(op->seq.size()); i++) { + cur_proxy = get_generic_proxy(op->seq[i]); + if (cur_proxy == Proxy::kAsync && prev_proxy == Proxy::kGeneric) { + new_body.push_back(fence_stmt); + } + new_body.push_back(VisitStmt(op->seq[i])); + prev_proxy = cur_proxy; + } + } + ICHECK(new_body.size() > 0); + return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)); + } + + // Stmt VisitStmt_(const ForNode* op) final { + // std::cout << "ForNode:" << op->body->GetTypeKey() << std::endl; + // return StmtExprMutator::VisitStmt_(op); + // } + + InjectFenceProxy() = default; +}; + +using namespace tir::transform; + +tvm::transform::Pass InjectFenceProxy() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return InjectFenceProxy::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {}); +} + +TVM_REGISTER_GLOBAL("tl.InjectFenceProxy").set_body_typed(InjectFenceProxy); + +} // namespace tl +} // namespace tvm diff --git a/src/tl/transform/pipeline_planning.cc b/src/tl/transform/pipeline_planning.cc index ff480e428bde..c2c2f389afae 100644 --- a/src/tl/transform/pipeline_planning.cc +++ b/src/tl/transform/pipeline_planning.cc @@ -66,7 +66,7 @@ class PipelinePlanner : public StmtExprMutator { substituter.buffer_data_to_buffer_.Set(buffer->data, buffer); } auto target = f->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "Layout_Inference: Require the target attribute"; + ICHECK(target.defined()) << "Pipeline_Planning: Require the target attribute"; substituter.target_ = target.value(); return substituter.VisitStmt(f->body); } diff --git a/tl_scripts/gemm_example.py b/tl_scripts/gemm_example.py index 7f9265b8f379..5fb1375b4a22 100644 --- a/tl_scripts/gemm_example.py +++ b/tl_scripts/gemm_example.py @@ -60,9 +60,8 @@ def main_hopper(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype), C: T.Buf T.copy(A[(bx % T.ceildiv(M, swizzle_M)) * swizzle_M + bz * block_M, k * block_K], A_shared) T.copy(B[k * block_K, T.floordiv(bx, T.ceildiv(N, swizzle_N)) * swizzle_N + by * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) - if targetHopper: - T.copy(C_local, C_shared) - T.copy(C_shared, C[(bx % T.ceildiv(M, swizzle_M)) * swizzle_M + bz * block_M, T.floordiv(bx, T.ceildiv(N, swizzle_N)) * swizzle_N + by * block_N]) + T.copy(C_local, C_shared) + T.copy(C_shared, C[(bx % T.ceildiv(M, swizzle_M)) * swizzle_M + bz * block_M, T.floordiv(bx, T.ceildiv(N, swizzle_N)) * swizzle_N + by * block_N]) if targetHopper: return main_hopper diff --git a/tl_scripts/gemm_rs_example.py b/tl_scripts/gemm_rs_example.py new file mode 100644 index 000000000000..362081a07393 --- /dev/null +++ b/tl_scripts/gemm_rs_example.py @@ -0,0 +1,71 @@ +import argparse +import torch +from tvm import tl +import tvm.tl.language as T +from tvm.tl.autotuner import * +import itertools + +targetHopper = True +swizzle_M = 4096 +swizzle_N = 4096 + + +def ref_program(A, B): + return A @ B + +def get_configs(): + block_M = 
[64] + block_N = [64] + block_K = [64] + num_stages = [1] + thread_num = [128] + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num)) + + configs = [ + {'block_M': c[0], 'block_N': c[1], 'block_K': c[2], 'num_stages': c[3], 'thread_num': c[4]} + for c in _configs + ] + return configs + +def matmul(M, N, K): + + @autotune(configs=get_configs(), keys=['block_M', 'block_N', 'block_K', 'num_stages', 'thread_num'], warmup=10, rep=5) + @jit(out_idx=[2], supply_type=tl.TensorSupplyType.Integer, ref_prog=ref_program) + def kernel(block_M = None, block_N = None, block_K = None, num_stages = None, thread_num = None): + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype), C: T.Buffer((M, N), dtype)): # type: ignore + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + A_local = T.alloc_fragment((block_M, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_local) + T.gemm(A_local, B_shared, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + return kernel() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--m', type=int, default=8192, help='M') + parser.add_argument('--n', type=int, default=8192, help='N') + parser.add_argument('--k', type=int, default=8192, help='K') + args = parser.parse_args() + M, N, K = args.m, args.n, args.k + total_flops = 2 * M * N * K + best_latency, best_config, ref_latency = matmul(M, N, K) + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") From db32751aba3e80c03a6593816553191727f65553 Mon Sep 17 00:00:00 2001 From: cy Date: Thu, 8 Aug 2024 10:12:49 +0800 Subject: [PATCH 03/23] add pipeline_transform --- tl_scripts/pipeline_transform.py | 150 +++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 tl_scripts/pipeline_transform.py diff --git a/tl_scripts/pipeline_transform.py b/tl_scripts/pipeline_transform.py new file mode 100644 index 000000000000..5d2ab185c7aa --- /dev/null +++ b/tl_scripts/pipeline_transform.py @@ -0,0 +1,150 @@ +from typing import Any, List, Dict, Optional, Tuple, Union + +class Edge: + def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, diff: int): + self.diff = diff + self.src_node = src_node + self.dst_node = dst_node + self.src_id = src_id + self.dst_id = dst_id + + def __repr__(self) -> str: + return "" + self.dst_node.name + ">" + +class DataDependencyEdge(Edge): + def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, diff: int): + super().__init__(src_node, dst_node, src_id, dst_id, diff) + + def __repr__(self) -> str: + return "" + self.dst_node.name + ">" + +class AntiDependencyEdge(Edge): + def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, diff: int): + super().__init__(src_node, dst_node, src_id, dst_id, diff) + + def __repr__(self) -> str: + return 
"" + self.dst_node.name + ">" + +class CondDependencyEdge(Edge): + def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, diff: int): + super().__init__(src_node, dst_node, src_id, dst_id, diff) + + def __repr__(self) -> str: + return "" + self.dst_node.name + ">" + +class Node: + def __init__(self, inputs: List[Union[Tuple['Node', int], 'Node', None]], name: str, sync_type: str): + self.name = name + self._out_edges = [] + self._in_edges = [] + self._shapes = [] + self._dtypes = [] + self._tag = {} + self._sync_type = sync_type + assert sync_type in ['sync', 'async'], 'invaild sync_type!' + + for i, node in enumerate(inputs): + if node is None: + inputs[i] = PlaceHolderNode() + + for dst_id, n in enumerate(inputs): + if isinstance(n, Node): + n = (n, 0) + assert(len(n) == 2) + src_node, src_id = n[0], n[1] + edge = DataDependencyEdge(src_node, self, src_id, dst_id, 0) + self._in_edges.append(edge) + src_node._out_edges.append(edge) + edge = AntiDependencyEdge(self, src_node, 0, 0, -1) + self._out_edges.append(edge) + src_node._in_edges.append(edge) + + @property + def inputs(self) -> List[Edge]: + return self._in_edges + + @property + def outputs(self) -> List[Edge]: + return self._out_edges + + def set_inputs(self, i: int, edge: Edge): + assert i < len(self._in_edges) + self._in_edges[i] = edge + + def set_outputs(self, i: int, edge: Edge): + assert i < len(self._out_edges) + self._out_edges[i] = edge + + def get_shape(self, id: int = 0) -> List[int]: + return self._shapes[id] + + def set_shape(self, shape: List[int], id=0, overwrite=False) -> None: + if len(self._shapes) <= id: + self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)]) + elif self._shapes[id] is not None and not overwrite: + assert self._shapes[id] == list(map(int, shape)), (self._shapes, list(map(int, shape))) + self._shapes[id] = list(map(int, shape)) + + def is_placeholder(self): + return False + + def is_output(self): + return False + + def add_tag(self, k: str, v: Any = True) -> None: + self._tag[k] = v + + def get_tag(self, k: str) -> Any: + if k not in self._tag: + return None + return self._tag[k] + + def num_outputs(self) -> int: + if len(self.outputs) == 0: + return 0 + return max([e.src_id for e in self.outputs]) + 1 + + def get_ir(self) -> str: + raise NotImplementedError() + + def __repr__(self) -> str: + return "" + +class PlaceHolderNode(Node): + def __init__(self, name=""): + super().__init__([], "PlaceHolder " + name, "sync") + + def is_placeholder(self): + return True + + def get_ir(self) -> str: + return "placeholder" + +class OutputNode(Node): + def __init__(self, node, id=0): + super().__init__([(node, id)], "Output ", "sync") + # self.set_shape(node.get_shape(id)) + # self.set_dtype(node.get_dtype(id)) + + def is_output(self): + return True + + def get_ir(self) -> str: + return "output" + + +def transform(nodes: List['Node'], stream_num: int, cond_deps) -> List['Node']: + return nodes + +# mha Example +if __name__ == "__main__": + loadk = Node(inputs=[None], name="loadk", sync_type='async') + mma0 = Node(inputs=[loadk], name="mma0", sync_type='async') + loadv = Node(inputs=[None], name="loadv", sync_type='async') + softmax = Node(inputs=[mma0], name="softmax", sync_type='sync') + mma1 = Node(inputs=[loadv, softmax], name="mma1", sync_type='async') + out = OutputNode(mma1) + ordered_nodes = [loadk, mma0, loadv, softmax, mma1, out] + + for edge in softmax.inputs: + print(edge) From 9d88410efb84049b7cc71dac161e8ae44322603f Mon Sep 17 00:00:00 2001 From: cy Date: 
Tue, 13 Aug 2024 21:37:55 +0800 Subject: [PATCH 04/23] add tl_pipeline, level0->level1->level2 lower process --- tl_pipeline/generate_plan.py | 173 ++++++++++++++++++++++ tl_pipeline/graph.py | 237 ++++++++++++++++++++++++++++++ tl_pipeline/main.py | 30 ++++ tl_pipeline/pipeline_transform.py | 89 +++++++++++ tl_scripts/pipeline_transform.py | 150 ------------------- 5 files changed, 529 insertions(+), 150 deletions(-) create mode 100644 tl_pipeline/generate_plan.py create mode 100644 tl_pipeline/graph.py create mode 100644 tl_pipeline/main.py create mode 100644 tl_pipeline/pipeline_transform.py delete mode 100644 tl_scripts/pipeline_transform.py diff --git a/tl_pipeline/generate_plan.py b/tl_pipeline/generate_plan.py new file mode 100644 index 000000000000..6c138f01aed1 --- /dev/null +++ b/tl_pipeline/generate_plan.py @@ -0,0 +1,173 @@ +from graph import * +from itertools import permutations, product + +class Instruction: + def __init__(self, node: Node, iter: int = -1): + self.instr = node + self.iter = iter + + def __repr__(self) -> str: + return repr(self.instr) + + def __eq__(self, other): + if type(self) is not type(other): + return False + return self.instr == other.instr and self.iter == other.iter + + def __gt__(self, other): + if type(self) is not type(other): + return False + return self.instr == other.instr and self.iter > other.iter + + def __lt__(self, other): + if type(self) is not type(other): + return False + return self.instr == other.instr and self.iter < other.iter + + def __ge__(self, other): + if type(self) is not type(other): + return False + return self.instr == other.instr and self.iter >= other.iter + + def __le__(self, other): + if type(self) is not type(other): + return False + return self.instr == other.instr and self.iter <= other.iter + + def __hash__(self): + return hash((self.instr, self.iter)) + +class Issue(Instruction): + def __init__(self, node: Node, iter: int = -1): + super().__init__(node, iter) + + def __repr__(self) -> str: + return super().__repr__() + f".issue({self.iter})" + + def __hash__(self): + return hash((super().__hash__(), 'Issue')) + +class Wait(Instruction): + def __init__(self, node: Node, iter: int = -1): + super().__init__(node, iter) + + def __repr__(self) -> str: + return super().__repr__() + f".wait({self.iter})" + + def __hash__(self): + return hash((super().__hash__(), 'Wait')) + +class Plan: + def __init__(self, instrs: List[Instruction], graph_id: int = -1) -> None: + self.instrs = instrs + self.graph_id = graph_id + + def set_graph_id(self, graph_id: int): + self.graph_id = graph_id + + def __repr__(self) -> str: + s = "" + for instr in self.instrs: + s += repr(instr) + s += "\n" + return s + + def __eq__(self, other): + if type(self) is not type(other): + return False + if len(self.instrs) != len(other.instrs): + return False + for i in range(len(self.instrs)): + if self.instrs[i] != other.instrs[i]: + return False + return True + + def __hash__(self): + return hash(tuple(self.instrs)) + +def is_valid_graph(graph: 'nx.DiGraph') -> bool: + if nx.negative_edge_cycle(graph): + return False + return True + +def get_plan(graph: 'nx.DiGraph', updated_graph: 'nx.DiGraph', ordered_nodes: List['Node']) -> Union[List['Instruction'], None]: + if not is_valid_graph(updated_graph): + return None + instrs = [] + + for node in ordered_nodes: + iteration = nx.bellman_ford_path_length(updated_graph, ordered_nodes[0], node) + for pre_node in list(graph.predecessors(node)): + if pre_node._sync_type == "sync": + continue + wait_instr = 
Wait(pre_node, iteration - graph[pre_node][node]['weight']) + flag = True + for instr in instrs: + # Unnecessary redundant waiting + if instr >= wait_instr: + flag = False + break + if flag: + instrs.append(wait_instr) + instrs.append(Issue(node, iteration)) + # for r in instrs: + # print(r) + # print("="*100) + return Plan(instrs) + +def update_base_on_order(graph: 'nx.DiGraph', ordered_nodes: List['Node']) -> 'nx.DiGraph': + g = graph.copy() + for s, d, data in g.edges(data=True): + if ordered_nodes.index(s) > ordered_nodes.index(d): + g[s][d]['weight'] = data['weight'] - 1 + return g + +def generate(graph: 'nx.DiGraph', topo_ordered_nodes: List['Node']) -> Union[List[List['Instruction']], None]: + # We only support tile-graph with 1 output node + # If there are 2 output nodes in tile-graph, you can add a new output node and add edge to this new node + # topo_ordered_nodes = list(nx.topological_sort(graph)) + output_nodes = [] + for node in graph.nodes(): + is_output = True + for _, _, data in graph.out_edges(node, data=True): + if data.get('type') == 'data': + is_output = False + break + if is_output: + output_nodes.append(node) + assert len(output_nodes) == 1, "Error: number of output_node is not 1." + results = [] + all_orders = list(permutations(range(len(graph.nodes) - 1))) + # print(all_orders) + + for order in all_orders: + ordered_nodes = [topo_ordered_nodes[-1]] + for i in range(len(graph.nodes) - 1): + ordered_nodes.append(topo_ordered_nodes[order[i]]) + updated_graph = update_base_on_order(graph, ordered_nodes) + plan = get_plan(graph, updated_graph, ordered_nodes) + if plan is not None: + results.append(plan) + return results + +if __name__ == "__main__": + _, graph = load_model("inputs.json") + topo_ordered_nodes = list(nx.topological_sort(graph)) + print("topo_ordered_nodes:", topo_ordered_nodes) + + v0, v3, v1, v2, v4 = topo_ordered_nodes + graph.add_edge(v4, v1, weight=2, type="control") + graph.add_edge(v4, v3, weight=2, type="control") + graph.add_edge(v1, v0, weight=1, type="control") + + print(graph) + print("Nodes:", graph.nodes(data=True)) + print("Edges:", graph.edges(data=True)) + + plans = generate(graph, topo_ordered_nodes) + + for i, plan in enumerate(plans): + print("-" * 100) + print(f"Plan {i}:") + for instr in plan: + print(instr) diff --git a/tl_pipeline/graph.py b/tl_pipeline/graph.py new file mode 100644 index 000000000000..476c5ad7f3e9 --- /dev/null +++ b/tl_pipeline/graph.py @@ -0,0 +1,237 @@ +from typing import Any, List, Dict, Optional, Tuple, Union +import heapq + +class Edge: + def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, w: int): + self.w = w + self.src_node = src_node + self.dst_node = dst_node + self.src_id = src_id + self.dst_id = dst_id + + def __repr__(self) -> str: + return "" + self.dst_node.name + ">" + +class DataDependencyEdge(Edge): + def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, w: int): + super().__init__(src_node, dst_node, src_id, dst_id, w) + + def __repr__(self) -> str: + return "" + self.dst_node.name + ">" + +class CondDependencyEdge(Edge): + def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, w: int): + super().__init__(src_node, dst_node, src_id, dst_id, w) + + def __repr__(self) -> str: + return "" + self.dst_node.name + ">" + +class Node: + def __init__(self, node_id: int, inputs: List[Union[Tuple['Node', int], 'Node', None]], name: str, sync_type: str): + self.id = node_id + self.name = name + self._out_edges = [] + 
self._in_edges = [] + self._shapes = [] + self._dtypes = [] + self._tag = {} + self._sync_type = sync_type + assert sync_type in ['sync', 'async'], 'invaild sync_type!' + + for i, node in enumerate(inputs): + if node is None: + inputs[i] = PlaceHolderNode() + + for dst_id, n in enumerate(inputs): + if isinstance(n, Node): + n = (n, 0) + assert(len(n) == 2) + src_node, src_id = n[0], n[1] + edge = DataDependencyEdge(src_node, self, src_id, dst_id, 0) + self._in_edges.append(edge) + src_node._out_edges.append(edge) + + @property + def inputs(self) -> List[Edge]: + return self._in_edges + + @property + def outputs(self) -> List[Edge]: + return self._out_edges + + def set_inputs(self, i: int, edge: Edge): + assert i < len(self._in_edges) + self._in_edges[i] = edge + + def set_outputs(self, i: int, edge: Edge): + assert i < len(self._out_edges) + self._out_edges[i] = edge + + def get_shape(self, id: int = 0) -> List[int]: + return self._shapes[id] + + def set_shape(self, shape: List[int], id=0, overwrite=False) -> None: + if len(self._shapes) <= id: + self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)]) + elif self._shapes[id] is not None and not overwrite: + assert self._shapes[id] == list(map(int, shape)), (self._shapes, list(map(int, shape))) + self._shapes[id] = list(map(int, shape)) + + def is_placeholder(self): + return False + + def is_output(self): + return False + + def add_tag(self, k: str, v: Any = True) -> None: + self._tag[k] = v + + def get_tag(self, k: str) -> Any: + if k not in self._tag: + return None + return self._tag[k] + + def num_outputs(self) -> int: + if len(self.outputs) == 0: + return 0 + return max([e.src_id for e in self.outputs]) + 1 + + def get_ir(self) -> str: + raise NotImplementedError() + + def __repr__(self) -> str: + # return "" + return self.name + +class PlaceHolderNode(Node): + def __init__(self, name=""): + super().__init__([], "PlaceHolder " + name, "sync") + + def is_placeholder(self): + return True + + def get_ir(self) -> str: + return "placeholder" + +class OutputNode(Node): + def __init__(self, node, id=0): + super().__init__([(node, id)], "Output ", "sync") + # self.set_shape(node.get_shape(id)) + # self.set_dtype(node.get_dtype(id)) + + def is_output(self): + return True + + def get_ir(self) -> str: + return "output" + +class Graph: + def __init__(self, nodes: List['Node']): + self.nodes = nodes + self.node_count = len(self.nodes) + self.dist = [[float('inf')] * self.node_count for _ in range(self.node_count)] + self.init_dist() + self.edge_count = 0 + for node in self.nodes: + self.edge_count += len(node.outputs) + + print("dist:") + print(self.dist) + print("edge_count:") + print(self.edge_count) + + def init_dist(self): + for i in range(self.node_count): + self.dist[i][i] = 0 + for node in self.nodes: + for edge in node.outputs: + self.add_edge(node, edge.dst_node, edge.w) + + def exist_edge(self, n0: 'Node', n1: 'Node') -> bool: + assert n0 in self.nodes, "n0 not in graph!" + assert n1 in self.nodes, "n1 not in graph!" 
+ for edge in n0.outputs: + if n1 == edge.dst_node: + return True + return False + + def update_dist(self, u, v, w): + dist = self.dist.copy() + if dist[u][v] > w: + dist[u][v] = w + + for i in range(self.node_count): + for j in range(self.node_count): + if dist[i][v] > dist[i][u] + w: + dist[i][v] = dist[i][u] + w + if dist[u][j] > w + dist[v][j]: + dist[u][j] = w + dist[v][j] + if dist[i][j] > dist[i][u] + w + dist[v][j]: + dist[i][j] = dist[i][u] + w + dist[v][j] + + # Detect negetive ring + for i in range(self.node_count): + for j in range(self.node_count): + if dist[i][v] > dist[i][u] + w: + return None + if dist[u][j] > w + dist[v][j]: + return None + if dist[i][j] > dist[i][u] + w + dist[v][j]: + return None + return dist + + + def add_edge(self, n0, n1, w) -> bool: + self.dist[n0.id][n1.id] = w + dist = self.update_dist(n0.id, n1.id, w) + if dist is None: + print("Invalid, negetive ring detected.") + return False # Invalid + self.dist = dist + print("updated dist:") + print(self.dist) + return True + + def del_edge(self, n0, n1): + pass + + def check_legality(self) -> bool: + return True + +def print_graph(nodes: List['Node']): + for node in nodes: + print(node) + for edge in node.inputs: + print(edge) + print('-'*100) + +def get_path_value(n0: 'Node', n1: 'Node') -> int: + + return + + +import networkx as nx +from graph import * +import json + +def load_model(fname: str) -> List[Node]: + graph = nx.DiGraph() + with open(fname) as f: + a = json.load(f) + node_map = {item[0] : None for item in a} + ordered_nodes = [] + for node_id, name, sync_type, is_output, inputs in a: + input_list = [] + for src_node, src_id in inputs: + if src_node not in node_map: + input_list.append(None) + else: + assert node_map[src_node] is not None, "Detected ring in topo order {}->{} !".format(src_node, node_id) + input_list.append([node_map[src_node], src_id]) + node = Node(node_id, input_list, name, sync_type) + for src_node, _ in inputs: + assert node_map[src_node] is not None + graph.add_edge(node_map[src_node], node, weight=0, type="data") + node_map[node_id] = node + ordered_nodes.append(node) + return ordered_nodes, graph \ No newline at end of file diff --git a/tl_pipeline/main.py b/tl_pipeline/main.py new file mode 100644 index 000000000000..78e207d43412 --- /dev/null +++ b/tl_pipeline/main.py @@ -0,0 +1,30 @@ +from pipeline_transform import * +from generate_plan import * + +if __name__ == "__main__": + _, graph = load_model("inputs.json") + print(graph) + print("Nodes:", graph.nodes(data=True)) + print("Edges:", graph.edges(data=True)) + + topo_ordered_nodes = list(nx.topological_sort(graph)) + print("topo_ordered_nodes:", topo_ordered_nodes) + + results = [] + add_backward_edges(graph, graph.copy(), list(reversed(topo_ordered_nodes)), stream=2, cur_id=0, dst_id=0, result=results) + plan_id = 0 + plans = [] + for i, g in enumerate(results): + print(f"Graph {i + 1}:") + print(g.edges(data=True)) + for p in generate(g, topo_ordered_nodes): + p.set_graph_id(i + 1) + plans.append(p) + + plans = list(set(plans)) + print(f"Plans: {len(plans)}") + for i, plan in enumerate(plans): + print(f"Plan {i + 1}:") + print("-" * 100) + print("Graph:", plan.graph_id) + print(plan) diff --git a/tl_pipeline/pipeline_transform.py b/tl_pipeline/pipeline_transform.py new file mode 100644 index 000000000000..00787ae9e578 --- /dev/null +++ b/tl_pipeline/pipeline_transform.py @@ -0,0 +1,89 @@ +import json +from graph import * + +import networkx as nx +from graph import * +import json +# from pipeline_transform import 
load_model + + + +def is_edge_valid(graph: 'nx.DiGraph', u: 'Node', v: 'Node', weight: int) -> bool: + if graph.has_edge(u, v): + return False + if nx.negative_edge_cycle(graph): + return False + # if nx.has_path(graph, u, v): + g = graph.copy() + g.add_edge(u, v, weight=weight) + shortest_paths = dict(nx.all_pairs_bellman_ford_path_length(g, weight='weight')) + + # print("-"*100) + # print("Edges:", g.edges(data=True)) + for s, d, data in g.edges(data=True): + edge_weight = data['weight'] + s_to_d_dist = shortest_paths[s][d] + + # print("edge_weight:", edge_weight, "s:", s, "d:", d) + if edge_weight > s_to_d_dist: + # print("is_edge_valid failed 0") + return False + + if edge_weight == s_to_d_dist: + if len(list(nx.all_shortest_paths(g, source=s, target=d, weight='weight'))) > 1: + # print("is_edge_valid failed 1") + # print("path:", path) + return False + return True + +def is_graph_valid(graph: 'nx.DiGraph', original_graph: 'nx.DiGraph'): + output_nodes = [node for node in original_graph.nodes() if original_graph.out_degree(node) == 0] + input_nodes = [node for node in original_graph.nodes() if original_graph.in_degree(node) == 0] + for output_node in output_nodes: + for input_node in input_nodes: + if not nx.has_path(graph, output_node, input_node): + return False + return True + +def add_backward_edges(graph: 'nx.DiGraph', original_graph: 'nx.DiGraph', ordered_nodes: List['Node'], stream: int, cur_id: int, dst_id: int, result: List['nx.DiGraph']): + if cur_id == len(ordered_nodes): + if is_graph_valid(graph, original_graph): + result.append(graph.copy()) + return + + if ordered_nodes[cur_id]._sync_type == "sync": + add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id + 1, 0, result) + return + + if dst_id == len(ordered_nodes): + add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id + 1, 0, result) + return + + u = ordered_nodes[cur_id] + v = ordered_nodes[dst_id] + if u in list(nx.ancestors(original_graph, v)) or u == v: + add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) + return + + add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) + for w in range(1, stream + 1): + if is_edge_valid(graph, u, v, w): + graph.add_edge(u, v, weight=w, type="control") + add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) + graph.remove_edge(u, v) + + +if __name__ == "__main__": + _, graph = load_model("inputs.json") + print(graph) + print("Nodes:", graph.nodes(data=True)) + print("Edges:", graph.edges(data=True)) + + ordered_nodes = list(nx.topological_sort(graph)) + print("ordered_nodes:", ordered_nodes) + + results = [] + add_backward_edges(graph, graph.copy(), list(reversed(ordered_nodes)), stream=2, cur_id=0, dst_id=0, result=results) + for i, g in enumerate(results): + print(f"Graph {i+1}:") + print(g.edges(data=True)) \ No newline at end of file diff --git a/tl_scripts/pipeline_transform.py b/tl_scripts/pipeline_transform.py deleted file mode 100644 index 5d2ab185c7aa..000000000000 --- a/tl_scripts/pipeline_transform.py +++ /dev/null @@ -1,150 +0,0 @@ -from typing import Any, List, Dict, Optional, Tuple, Union - -class Edge: - def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, diff: int): - self.diff = diff - self.src_node = src_node - self.dst_node = dst_node - self.src_id = src_id - self.dst_id = dst_id - - def __repr__(self) -> str: - return "" + self.dst_node.name + ">" - -class 
DataDependencyEdge(Edge): - def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, diff: int): - super().__init__(src_node, dst_node, src_id, dst_id, diff) - - def __repr__(self) -> str: - return "" + self.dst_node.name + ">" - -class AntiDependencyEdge(Edge): - def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, diff: int): - super().__init__(src_node, dst_node, src_id, dst_id, diff) - - def __repr__(self) -> str: - return "" + self.dst_node.name + ">" - -class CondDependencyEdge(Edge): - def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, diff: int): - super().__init__(src_node, dst_node, src_id, dst_id, diff) - - def __repr__(self) -> str: - return "" + self.dst_node.name + ">" - -class Node: - def __init__(self, inputs: List[Union[Tuple['Node', int], 'Node', None]], name: str, sync_type: str): - self.name = name - self._out_edges = [] - self._in_edges = [] - self._shapes = [] - self._dtypes = [] - self._tag = {} - self._sync_type = sync_type - assert sync_type in ['sync', 'async'], 'invaild sync_type!' - - for i, node in enumerate(inputs): - if node is None: - inputs[i] = PlaceHolderNode() - - for dst_id, n in enumerate(inputs): - if isinstance(n, Node): - n = (n, 0) - assert(len(n) == 2) - src_node, src_id = n[0], n[1] - edge = DataDependencyEdge(src_node, self, src_id, dst_id, 0) - self._in_edges.append(edge) - src_node._out_edges.append(edge) - edge = AntiDependencyEdge(self, src_node, 0, 0, -1) - self._out_edges.append(edge) - src_node._in_edges.append(edge) - - @property - def inputs(self) -> List[Edge]: - return self._in_edges - - @property - def outputs(self) -> List[Edge]: - return self._out_edges - - def set_inputs(self, i: int, edge: Edge): - assert i < len(self._in_edges) - self._in_edges[i] = edge - - def set_outputs(self, i: int, edge: Edge): - assert i < len(self._out_edges) - self._out_edges[i] = edge - - def get_shape(self, id: int = 0) -> List[int]: - return self._shapes[id] - - def set_shape(self, shape: List[int], id=0, overwrite=False) -> None: - if len(self._shapes) <= id: - self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)]) - elif self._shapes[id] is not None and not overwrite: - assert self._shapes[id] == list(map(int, shape)), (self._shapes, list(map(int, shape))) - self._shapes[id] = list(map(int, shape)) - - def is_placeholder(self): - return False - - def is_output(self): - return False - - def add_tag(self, k: str, v: Any = True) -> None: - self._tag[k] = v - - def get_tag(self, k: str) -> Any: - if k not in self._tag: - return None - return self._tag[k] - - def num_outputs(self) -> int: - if len(self.outputs) == 0: - return 0 - return max([e.src_id for e in self.outputs]) + 1 - - def get_ir(self) -> str: - raise NotImplementedError() - - def __repr__(self) -> str: - return "" - -class PlaceHolderNode(Node): - def __init__(self, name=""): - super().__init__([], "PlaceHolder " + name, "sync") - - def is_placeholder(self): - return True - - def get_ir(self) -> str: - return "placeholder" - -class OutputNode(Node): - def __init__(self, node, id=0): - super().__init__([(node, id)], "Output ", "sync") - # self.set_shape(node.get_shape(id)) - # self.set_dtype(node.get_dtype(id)) - - def is_output(self): - return True - - def get_ir(self) -> str: - return "output" - - -def transform(nodes: List['Node'], stream_num: int, cond_deps) -> List['Node']: - return nodes - -# mha Example -if __name__ == "__main__": - loadk = Node(inputs=[None], name="loadk", 
sync_type='async') - mma0 = Node(inputs=[loadk], name="mma0", sync_type='async') - loadv = Node(inputs=[None], name="loadv", sync_type='async') - softmax = Node(inputs=[mma0], name="softmax", sync_type='sync') - mma1 = Node(inputs=[loadv, softmax], name="mma1", sync_type='async') - out = OutputNode(mma1) - ordered_nodes = [loadk, mma0, loadv, softmax, mma1, out] - - for edge in softmax.inputs: - print(edge) From a350f94405f385e409c24efd76a27db1620a5afa Mon Sep 17 00:00:00 2001 From: cy Date: Tue, 13 Aug 2024 22:22:18 +0800 Subject: [PATCH 05/23] update --- tl_pipeline/generate_plan.py | 7 +++++++ tl_pipeline/main.py | 1 - 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tl_pipeline/generate_plan.py b/tl_pipeline/generate_plan.py index 6c138f01aed1..05e63704ad53 100644 --- a/tl_pipeline/generate_plan.py +++ b/tl_pipeline/generate_plan.py @@ -107,6 +107,13 @@ def get_plan(graph: 'nx.DiGraph', updated_graph: 'nx.DiGraph', ordered_nodes: Li if instr >= wait_instr: flag = False break + # Example: + # mma.wait(0) + # ... + # mma.wait(1) + # Then mma.wait(0) can be removed because this wait is done in the previous iteration + if instr < wait_instr: + instrs.remove(instr) if flag: instrs.append(wait_instr) instrs.append(Issue(node, iteration)) diff --git a/tl_pipeline/main.py b/tl_pipeline/main.py index 78e207d43412..d02f548a9814 100644 --- a/tl_pipeline/main.py +++ b/tl_pipeline/main.py @@ -12,7 +12,6 @@ results = [] add_backward_edges(graph, graph.copy(), list(reversed(topo_ordered_nodes)), stream=2, cur_id=0, dst_id=0, result=results) - plan_id = 0 plans = [] for i, g in enumerate(results): print(f"Graph {i + 1}:") From ba425baf81d993320fb740a03a4a756569566ed9 Mon Sep 17 00:00:00 2001 From: cy Date: Wed, 14 Aug 2024 00:54:35 +0800 Subject: [PATCH 06/23] add add_forward_edges, prune based on hw an buffer memory layer --- tl_pipeline/graph.py | 8 +++-- tl_pipeline/main.py | 4 +-- tl_pipeline/pipeline_transform.py | 59 +++++++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/tl_pipeline/graph.py b/tl_pipeline/graph.py index 476c5ad7f3e9..adc23f59e0dc 100644 --- a/tl_pipeline/graph.py +++ b/tl_pipeline/graph.py @@ -27,7 +27,7 @@ def __repr__(self) -> str: return "" + self.dst_node.name + ">" class Node: - def __init__(self, node_id: int, inputs: List[Union[Tuple['Node', int], 'Node', None]], name: str, sync_type: str): + def __init__(self, node_id: int, inputs: List[Union[Tuple['Node', int], 'Node', None]], name: str, sync_type: str, hw_usage: List[str], buffer_layer: List[str]): self.id = node_id self.name = name self._out_edges = [] @@ -36,6 +36,8 @@ def __init__(self, node_id: int, inputs: List[Union[Tuple['Node', int], 'Node', self._dtypes = [] self._tag = {} self._sync_type = sync_type + self.hw_usage = hw_usage + self.buffer_layer = buffer_layer assert sync_type in ['sync', 'async'], 'invaild sync_type!' 
for i, node in enumerate(inputs): @@ -220,7 +222,7 @@ def load_model(fname: str) -> List[Node]: a = json.load(f) node_map = {item[0] : None for item in a} ordered_nodes = [] - for node_id, name, sync_type, is_output, inputs in a: + for node_id, name, sync_type, hw_usage, buffer_layer, is_output, inputs in a: input_list = [] for src_node, src_id in inputs: if src_node not in node_map: @@ -228,7 +230,7 @@ def load_model(fname: str) -> List[Node]: else: assert node_map[src_node] is not None, "Detected ring in topo order {}->{} !".format(src_node, node_id) input_list.append([node_map[src_node], src_id]) - node = Node(node_id, input_list, name, sync_type) + node = Node(node_id, input_list, name, sync_type, hw_usage, buffer_layer) for src_node, _ in inputs: assert node_map[src_node] is not None graph.add_edge(node_map[src_node], node, weight=0, type="data") diff --git a/tl_pipeline/main.py b/tl_pipeline/main.py index d02f548a9814..d6cce3fb86c5 100644 --- a/tl_pipeline/main.py +++ b/tl_pipeline/main.py @@ -10,8 +10,8 @@ topo_ordered_nodes = list(nx.topological_sort(graph)) print("topo_ordered_nodes:", topo_ordered_nodes) - results = [] - add_backward_edges(graph, graph.copy(), list(reversed(topo_ordered_nodes)), stream=2, cur_id=0, dst_id=0, result=results) + results = transform(graph, graph.copy(), topo_ordered_nodes, stream=2) + plans = [] for i, g in enumerate(results): print(f"Graph {i + 1}:") diff --git a/tl_pipeline/pipeline_transform.py b/tl_pipeline/pipeline_transform.py index 00787ae9e578..28bde4bb6d04 100644 --- a/tl_pipeline/pipeline_transform.py +++ b/tl_pipeline/pipeline_transform.py @@ -11,11 +11,10 @@ def is_edge_valid(graph: 'nx.DiGraph', u: 'Node', v: 'Node', weight: int) -> bool: if graph.has_edge(u, v): return False - if nx.negative_edge_cycle(graph): - return False - # if nx.has_path(graph, u, v): g = graph.copy() g.add_edge(u, v, weight=weight) + if nx.negative_edge_cycle(g): + return False shortest_paths = dict(nx.all_pairs_bellman_ford_path_length(g, weight='weight')) # print("-"*100) @@ -30,10 +29,12 @@ def is_edge_valid(graph: 'nx.DiGraph', u: 'Node', v: 'Node', weight: int) -> boo return False if edge_weight == s_to_d_dist: - if len(list(nx.all_shortest_paths(g, source=s, target=d, weight='weight'))) > 1: - # print("is_edge_valid failed 1") - # print("path:", path) - return False + _g = g.copy() + _g.remove_edge(s, d) + if nx.has_path(_g, s, d): + new_s_to_d_dist = nx.bellman_ford_path_length(_g, s, d, weight='weight') + if new_s_to_d_dist == edge_weight: + return False return True def is_graph_valid(graph: 'nx.DiGraph', original_graph: 'nx.DiGraph'): @@ -65,6 +66,11 @@ def add_backward_edges(graph: 'nx.DiGraph', original_graph: 'nx.DiGraph', ordere add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) return + common_hw = [h for h in u.hw_usage if h in v.hw_usage] + if len(common_hw) == 0 and u.buffer_layer[0] != v.buffer_layer[1]: + add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) + return + add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) for w in range(1, stream + 1): if is_edge_valid(graph, u, v, w): @@ -72,6 +78,45 @@ def add_backward_edges(graph: 'nx.DiGraph', original_graph: 'nx.DiGraph', ordere add_backward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) graph.remove_edge(u, v) +def add_forward_edges(graph: 'nx.DiGraph', original_graph: 'nx.DiGraph', ordered_nodes: List['Node'], stream: int, cur_id: int, 
dst_id: int, result: List['nx.DiGraph']): + if cur_id == len(ordered_nodes): + if is_graph_valid(graph, original_graph): + result.append(graph.copy()) + return + + if ordered_nodes[cur_id]._sync_type == "sync": + add_forward_edges(graph, original_graph, ordered_nodes, stream, cur_id + 1, 0, result) + return + + if dst_id == len(ordered_nodes): + add_forward_edges(graph, original_graph, ordered_nodes, stream, cur_id + 1, 0, result) + return + + u = ordered_nodes[cur_id] + v = ordered_nodes[dst_id] + if u in list(nx.descendants(original_graph, v)) or u == v: + add_forward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) + return + + common_hw = [h for h in u.hw_usage if h in v.hw_usage] + if len(common_hw) == 0 and u.buffer_layer[0] != v.buffer_layer[1]: + add_forward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) + return + + add_forward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) + for w in range(-stream, 0): + if is_edge_valid(graph, u, v, w): + graph.add_edge(u, v, weight=w, type="control") + add_forward_edges(graph, original_graph, ordered_nodes, stream, cur_id, dst_id + 1, result) + graph.remove_edge(u, v) + +def transform(graph: 'nx.DiGraph', original_graph: 'nx.DiGraph', topo_ordered_nodes: List['Node'], stream: int): + results = [] + interm_results = [] + add_backward_edges(graph, original_graph, list(reversed(topo_ordered_nodes)), stream=stream, cur_id=0, dst_id=0, result=interm_results) + for modified_graph in interm_results: + add_forward_edges(modified_graph, original_graph, list(topo_ordered_nodes), stream=stream, cur_id=0, dst_id=0, result=results) + return results if __name__ == "__main__": _, graph = load_model("inputs.json") From 7aec46f995f192ba214dd078a2479d3d79af644d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 Aug 2024 09:29:04 +0000 Subject: [PATCH 07/23] [tl] Add dp for pipeline scheduling --- tl_pipeline/dp.py | 103 +++++++++++++++++++++++++++++++++++++++++++ tl_pipeline/graph.py | 7 +-- 2 files changed, 107 insertions(+), 3 deletions(-) create mode 100644 tl_pipeline/dp.py diff --git a/tl_pipeline/dp.py b/tl_pipeline/dp.py new file mode 100644 index 000000000000..d759f6751be1 --- /dev/null +++ b/tl_pipeline/dp.py @@ -0,0 +1,103 @@ +import networkx as nx +from graph import * + +class Schedule: + def __init__(self, order, hw_time): + self.order = order + self.hw_time = {} + self.hw_time["tensor_core"] = hw_time["tensor_core"] + self.hw_time["cuda_core"] = hw_time["cuda_core"] + self.hw_time["tma"] = hw_time["tma"] + + @property + def time(self) -> float: + return max(self.hw_time["tensor_core"], self.hw_time["cuda_core"], self.hw_time["tma"]) + + def __repr__(self) -> str: + nodes_str = [repr(node) for node in self.order] + s = ", ".join(nodes_str) + s += "\n" + s += f"Time:{self.time}" + return s + +dp_dict = {} + +def get_schedule(nodes: List[Node], in_node: Node, in_node_dsts: List[Node]) -> Schedule: + schedule = dp_dict[frozenset(nodes)] + new_order = [in_node] + schedule.order + end_time = schedule.hw_time[in_node.hw_usage] + for dst_node in in_node_dsts: + if schedule.hw_time[dst_node.hw_usage] > end_time: + end_time = schedule.hw_time[dst_node.hw_usage] + if in_node._sync_type == "sync": + end_time = schedule.time + start_time = end_time + in_node.time + # start time can not be later than any node in nodes + if start_time < schedule.time: + start_time = schedule.time + + new_hw_time = (schedule.hw_time).copy() + new_hw_time[in_node.hw_usage] = start_time 
+ return Schedule(new_order, new_hw_time) + +def dp(graph: Graph, nodes: List[Node]): + global dp_dict + pred_nodes = {} + + subgraph = graph.subgraph(nodes) + for node in subgraph.nodes: + for pred in graph.predecessors(node): + valid = True + if pred not in subgraph: + for pred_succ in graph.successors(pred): + if pred_succ not in subgraph: + valid = False + break + if valid: + pred_nodes[pred] = [succ for succ in graph.successors(pred) if succ in subgraph] + # for node in pred_nodes: + # print(node) + for in_node, in_node_dsts in pred_nodes.items(): + schedule = get_schedule(nodes, in_node, in_node_dsts) + new_nodes = nodes + [in_node] + new_nodes.append(in_node) + if frozenset(new_nodes) in dp_dict and dp_dict[frozenset(new_nodes)].time < schedule.time: + continue + dp_dict[frozenset(new_nodes)] = schedule + dp(graph, nodes + [in_node]) + +def duplicate_with_stream(graph: Graph, stream: int) -> Tuple[Graph, Node]: + node_num = len(graph.nodes) + node_list = [] + duplicate_graph = nx.DiGraph() + for i in range(stream): + mapping = {node: Node(node.id, [], node.name+f'_{i}', node._sync_type, node.hw_usage, node.time, node.buffer_layer) for node in graph.nodes} + g = nx.relabel_nodes(graph, mapping) + node_list.extend(g.nodes) + duplicate_graph = nx.compose(duplicate_graph, g) + + if i > 0: + for nid in range(node_num): + duplicate_graph.add_edge(node_list[(i - 1) * node_num + nid], node_list[i * node_num + nid], type="control.issue") + + out_nodes = [node for node in duplicate_graph.nodes if duplicate_graph.out_degree(node) == 0] + output_node = Node(node_num * stream, [], "Output", "sync", "cuda_core", 0, []) + + for out_node in out_nodes: + duplicate_graph.add_edge(out_node, output_node, type="data") + return duplicate_graph, output_node + +if __name__ == "__main__": + stream = 3 + _, graph = load_model("inputs.json") + duplicate_graph, output_node = duplicate_with_stream(graph, stream) + + print(duplicate_graph) + print(duplicate_graph.edges(data=True)) + # for node in graph.nodes: + # print(node) + dp_dict[frozenset([output_node])] = Schedule([output_node], {"tensor_core":0, "cuda_core":0, "tma":0}) + dp(duplicate_graph, [output_node]) + + print("dp_dict", dp_dict) + print("results", dp_dict[frozenset(list(duplicate_graph.nodes))]) \ No newline at end of file diff --git a/tl_pipeline/graph.py b/tl_pipeline/graph.py index adc23f59e0dc..f08fe75e92f5 100644 --- a/tl_pipeline/graph.py +++ b/tl_pipeline/graph.py @@ -27,7 +27,7 @@ def __repr__(self) -> str: return "" + self.dst_node.name + ">" class Node: - def __init__(self, node_id: int, inputs: List[Union[Tuple['Node', int], 'Node', None]], name: str, sync_type: str, hw_usage: List[str], buffer_layer: List[str]): + def __init__(self, node_id: int, inputs: List[Union[Tuple['Node', int], 'Node', None]], name: str, sync_type: str, hw_usage: List[str], time: float, buffer_layer: List[str]): self.id = node_id self.name = name self._out_edges = [] @@ -37,6 +37,7 @@ def __init__(self, node_id: int, inputs: List[Union[Tuple['Node', int], 'Node', self._tag = {} self._sync_type = sync_type self.hw_usage = hw_usage + self.time = time self.buffer_layer = buffer_layer assert sync_type in ['sync', 'async'], 'invaild sync_type!' 
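+        # The new `time` field is the per-node latency used as the cost model by
+        # the DP scheduler in tl_pipeline/dp.py; load_model below now expects it
+        # as the fifth field of each JSON record, e.g. (values illustrative):
+        #   [3, "gemm_0", "async", "tensor_core", 1.5, ["smem"], false, [[2, 0]]]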
@@ -222,7 +223,7 @@ def load_model(fname: str) -> List[Node]: a = json.load(f) node_map = {item[0] : None for item in a} ordered_nodes = [] - for node_id, name, sync_type, hw_usage, buffer_layer, is_output, inputs in a: + for node_id, name, sync_type, hw_usage, time, buffer_layer, is_output, inputs in a: input_list = [] for src_node, src_id in inputs: if src_node not in node_map: @@ -230,7 +231,7 @@ def load_model(fname: str) -> List[Node]: else: assert node_map[src_node] is not None, "Detected ring in topo order {}->{} !".format(src_node, node_id) input_list.append([node_map[src_node], src_id]) - node = Node(node_id, input_list, name, sync_type, hw_usage, buffer_layer) + node = Node(node_id, input_list, name, sync_type, hw_usage, time, buffer_layer) for src_node, _ in inputs: assert node_map[src_node] is not None graph.add_edge(node_map[src_node], node, weight=0, type="data") From ccef8325b7a22cd611f4c67b293c1ae18e0a7240 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 Aug 2024 17:43:31 +0000 Subject: [PATCH 08/23] [tl] Update dp algorithm --- tl_pipeline/dp.py | 58 +++++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/tl_pipeline/dp.py b/tl_pipeline/dp.py index d759f6751be1..3a4a919a08a7 100644 --- a/tl_pipeline/dp.py +++ b/tl_pipeline/dp.py @@ -15,36 +15,51 @@ def time(self) -> float: def __repr__(self) -> str: nodes_str = [repr(node) for node in self.order] - s = ", ".join(nodes_str) + s = f"len={len(nodes_str)}, " + s += ", ".join(nodes_str) + s += "\n" + s += f"Time:{self.time}, " + s += f"tensor_core:{self.hw_time['tensor_core']}, " + s += f"cuda_core:{self.hw_time['cuda_core']}, " + s += f"tma:{self.hw_time['tma']}, " s += "\n" - s += f"Time:{self.time}" return s dp_dict = {} +i = 0 +issue_latency = 0.001 -def get_schedule(nodes: List[Node], in_node: Node, in_node_dsts: List[Node]) -> Schedule: +def get_schedule(nodes: List[Node], in_node: Node, in_node_dsts: List[Node], graph: Graph) -> Schedule: schedule = dp_dict[frozenset(nodes)] new_order = [in_node] + schedule.order - end_time = schedule.hw_time[in_node.hw_usage] - for dst_node in in_node_dsts: - if schedule.hw_time[dst_node.hw_usage] > end_time: - end_time = schedule.hw_time[dst_node.hw_usage] - if in_node._sync_type == "sync": - end_time = schedule.time - start_time = end_time + in_node.time - # start time can not be later than any node in nodes - if start_time < schedule.time: - start_time = schedule.time - - new_hw_time = (schedule.hw_time).copy() - new_hw_time[in_node.hw_usage] = start_time + new_hw_time = {"tensor_core":0, "cuda_core":0, "tma":0} + # node: (start_time, end_time) + issue_dict = {} + last_start_time = 0 + for cur_node in new_order: + start_time = max(new_hw_time[cur_node.hw_usage], last_start_time + issue_latency) + for pred_node in new_order: + # Need to issue after previous sync node done + if pred_node in issue_dict and pred_node._sync_type == "sync": + start_time = max(start_time, issue_dict[pred_node][1] + issue_latency) + if graph.has_edge(pred_node, cur_node): + edge = graph.get_edge_data(pred_node, cur_node) + if edge['type'] == "data": + assert pred_node in issue_dict, "error: pred_node not in issue_dict" + # Data denpendency + start_time = max(start_time, issue_dict[pred_node][1] + issue_latency) + issue_dict[cur_node] = [start_time, start_time + cur_node.time] + new_hw_time[cur_node.hw_usage] = start_time + cur_node.time + last_start_time = start_time return Schedule(new_order, new_hw_time) + def dp(graph: Graph, nodes: List[Node]): 
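+    # dp() grows the scheduled set one producer at a time: a node outside the
+    # current subgraph becomes a candidate only when all of its successors are
+    # already inside, so prepending it cannot break a dependency. Each candidate
+    # yields a Schedule stored in dp_dict under the frozenset of scheduled nodes,
+    # unless the same node set was already reached with a smaller makespan.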
global dp_dict + global i pred_nodes = {} - subgraph = graph.subgraph(nodes) + subgraph = graph.subgraph(nodes.copy()) for node in subgraph.nodes: for pred in graph.predecessors(node): valid = True @@ -58,9 +73,10 @@ def dp(graph: Graph, nodes: List[Node]): # for node in pred_nodes: # print(node) for in_node, in_node_dsts in pred_nodes.items(): - schedule = get_schedule(nodes, in_node, in_node_dsts) + print(f"iter {i}") + i += 1 + schedule = get_schedule(nodes, in_node, in_node_dsts, graph) new_nodes = nodes + [in_node] - new_nodes.append(in_node) if frozenset(new_nodes) in dp_dict and dp_dict[frozenset(new_nodes)].time < schedule.time: continue dp_dict[frozenset(new_nodes)] = schedule @@ -99,5 +115,7 @@ def duplicate_with_stream(graph: Graph, stream: int) -> Tuple[Graph, Node]: dp_dict[frozenset([output_node])] = Schedule([output_node], {"tensor_core":0, "cuda_core":0, "tma":0}) dp(duplicate_graph, [output_node]) - print("dp_dict", dp_dict) + # for k, v in dp_dict.items(): + # print(k) + # print(v) print("results", dp_dict[frozenset(list(duplicate_graph.nodes))]) \ No newline at end of file From 61b2ef6684aefbf8b4ed9bd159ed5140d694d6f1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 22 Aug 2024 08:22:00 +0000 Subject: [PATCH 09/23] [tl] Update dp algorithm --- tl_pipeline/dp.py | 119 ++++++++++++++++++++++++++-------------------- 1 file changed, 68 insertions(+), 51 deletions(-) diff --git a/tl_pipeline/dp.py b/tl_pipeline/dp.py index 3a4a919a08a7..1f85bccb0847 100644 --- a/tl_pipeline/dp.py +++ b/tl_pipeline/dp.py @@ -13,9 +13,19 @@ def __init__(self, order, hw_time): def time(self) -> float: return max(self.hw_time["tensor_core"], self.hw_time["cuda_core"], self.hw_time["tma"]) + def __lt__(self, other): + return isinstance(other, Schedule) and self.order != other.order and self.time < other.time + + def __eq__(self, other): + return isinstance(other, Schedule) and self.order == other.order + + def __hash__(self): + return hash(tuple(self.order)) + def __repr__(self) -> str: nodes_str = [repr(node) for node in self.order] - s = f"len={len(nodes_str)}, " + # s = f"len={len(nodes_str)}, " + s = "" s += ", ".join(nodes_str) s += "\n" s += f"Time:{self.time}, " @@ -26,61 +36,66 @@ def __repr__(self) -> str: return s dp_dict = {} +topk = 10 i = 0 issue_latency = 0.001 -def get_schedule(nodes: List[Node], in_node: Node, in_node_dsts: List[Node], graph: Graph) -> Schedule: - schedule = dp_dict[frozenset(nodes)] - new_order = [in_node] + schedule.order - new_hw_time = {"tensor_core":0, "cuda_core":0, "tma":0} - # node: (start_time, end_time) - issue_dict = {} - last_start_time = 0 - for cur_node in new_order: - start_time = max(new_hw_time[cur_node.hw_usage], last_start_time + issue_latency) - for pred_node in new_order: - # Need to issue after previous sync node done - if pred_node in issue_dict and pred_node._sync_type == "sync": - start_time = max(start_time, issue_dict[pred_node][1] + issue_latency) - if graph.has_edge(pred_node, cur_node): - edge = graph.get_edge_data(pred_node, cur_node) - if edge['type'] == "data": - assert pred_node in issue_dict, "error: pred_node not in issue_dict" - # Data denpendency +def get_schedule(nodes: List[Node], in_node: Node, in_node_dsts: List[Node], graph: Graph) -> List[Schedule]: + schedule_list = [] + for schedule in dp_dict[frozenset(nodes)]: + new_order = [in_node] + schedule.order + new_hw_time = {"tensor_core":0, "cuda_core":0, "tma":0} + # node: (start_time, end_time) + issue_dict = {} + last_start_time = 0 + for cur_node in new_order: + 
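+            # Replay the whole candidate order and recompute every issue time:
+            # a node may start only when (a) its hardware unit is free again,
+            # (b) at least issue_latency after the previously issued node,
+            # (c) after every earlier "sync" node has completed, and
+            # (d) after every producer reaching it via a "data" edge has completed.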
start_time = max(new_hw_time[cur_node.hw_usage], last_start_time + issue_latency) + for pred_node in new_order: + # Need to issue after previous sync node done + if pred_node in issue_dict and pred_node._sync_type == "sync": start_time = max(start_time, issue_dict[pred_node][1] + issue_latency) - issue_dict[cur_node] = [start_time, start_time + cur_node.time] - new_hw_time[cur_node.hw_usage] = start_time + cur_node.time - last_start_time = start_time - return Schedule(new_order, new_hw_time) + if graph.has_edge(pred_node, cur_node): + edge = graph.get_edge_data(pred_node, cur_node) + if edge['type'] == "data": + assert pred_node in issue_dict, "error: pred_node not in issue_dict" + # Data denpendency + start_time = max(start_time, issue_dict[pred_node][1] + issue_latency) + issue_dict[cur_node] = [start_time, start_time + cur_node.time] + new_hw_time[cur_node.hw_usage] = start_time + cur_node.time + last_start_time = start_time + schedule_list.append(Schedule(new_order, new_hw_time)) + return sorted(schedule_list) -def dp(graph: Graph, nodes: List[Node]): +def dp(graph: Graph, prev_sub_graphs: List[List[Node]]): global dp_dict global i - pred_nodes = {} - - subgraph = graph.subgraph(nodes.copy()) - for node in subgraph.nodes: - for pred in graph.predecessors(node): - valid = True - if pred not in subgraph: - for pred_succ in graph.successors(pred): - if pred_succ not in subgraph: - valid = False - break - if valid: - pred_nodes[pred] = [succ for succ in graph.successors(pred) if succ in subgraph] - # for node in pred_nodes: - # print(node) - for in_node, in_node_dsts in pred_nodes.items(): - print(f"iter {i}") - i += 1 - schedule = get_schedule(nodes, in_node, in_node_dsts, graph) - new_nodes = nodes + [in_node] - if frozenset(new_nodes) in dp_dict and dp_dict[frozenset(new_nodes)].time < schedule.time: - continue - dp_dict[frozenset(new_nodes)] = schedule - dp(graph, nodes + [in_node]) + # set(sub_graph_nodes): [in_node_0, in_node_1, ...] 
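+    # Unlike the earlier per-node recursion, this version expands level by level:
+    # every subgraph of size k is extended by each eligible producer, candidates
+    # that reach the same node set are merged under one frozenset key, and only
+    # the `topk` schedules with the smallest makespan are kept, a beam-search
+    # style pruning of dp_dict.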
+ sub_graphs_dict = {} + for prev_nodes in prev_sub_graphs: + subgraph = graph.subgraph(prev_nodes.copy()) + for node in subgraph.nodes: + for pred in graph.predecessors(node): + valid = True + if pred not in subgraph: + for pred_succ in graph.successors(pred): + if pred_succ not in subgraph: + valid = False + break + if valid: + key = frozenset(prev_nodes + [pred]) + if key not in sub_graphs_dict: + sub_graphs_dict[key] = [] + sub_graphs_dict[key].append((pred, prev_nodes)) + for new_nodes_set, nodes in sub_graphs_dict.items(): + schedule_list = [] + for in_node, prev_nodes in nodes: + schedule_list.extend(get_schedule(prev_nodes, in_node, [], graph)) + dp_dict[new_nodes_set] = sorted(set(schedule_list))[:topk] + sub_graphs = list(list(n) for n in list(sub_graphs_dict)) + if len(sub_graphs) > 0: + dp(graph, sub_graphs) + def duplicate_with_stream(graph: Graph, stream: int) -> Tuple[Graph, Node]: node_num = len(graph.nodes) @@ -112,10 +127,12 @@ def duplicate_with_stream(graph: Graph, stream: int) -> Tuple[Graph, Node]: print(duplicate_graph.edges(data=True)) # for node in graph.nodes: # print(node) - dp_dict[frozenset([output_node])] = Schedule([output_node], {"tensor_core":0, "cuda_core":0, "tma":0}) - dp(duplicate_graph, [output_node]) + dp_dict[frozenset([output_node])] = [Schedule([output_node], {"tensor_core":0, "cuda_core":0, "tma":0})] + dp(duplicate_graph, [[output_node]]) # for k, v in dp_dict.items(): # print(k) # print(v) - print("results", dp_dict[frozenset(list(duplicate_graph.nodes))]) \ No newline at end of file + print("results:") + for schedule in dp_dict[frozenset(list(duplicate_graph.nodes))]: + print(schedule) \ No newline at end of file From 9db6ddbd1865c32a6cbe164ba655078194dc542c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 23 Aug 2024 08:58:40 +0000 Subject: [PATCH 10/23] [tl] Add memory dimension to dp --- tl_pipeline/dp.py | 135 +++++++++++++++++++++---- tl_pipeline/graph.py | 235 ++++++------------------------------------- 2 files changed, 146 insertions(+), 224 deletions(-) diff --git a/tl_pipeline/dp.py b/tl_pipeline/dp.py index 1f85bccb0847..19a2282ea90a 100644 --- a/tl_pipeline/dp.py +++ b/tl_pipeline/dp.py @@ -1,9 +1,11 @@ import networkx as nx from graph import * +from functools import lru_cache class Schedule: - def __init__(self, order, hw_time): + def __init__(self, order: List[Node], sync: List[Tuple[Node, Node]], hw_time: Dict): self.order = order + self.sync = sync self.hw_time = {} self.hw_time["tensor_core"] = hw_time["tensor_core"] self.hw_time["cuda_core"] = hw_time["cuda_core"] @@ -14,20 +16,23 @@ def time(self) -> float: return max(self.hw_time["tensor_core"], self.hw_time["cuda_core"], self.hw_time["tma"]) def __lt__(self, other): - return isinstance(other, Schedule) and self.order != other.order and self.time < other.time + return isinstance(other, Schedule) and (self.order != other.order or self.sync != other.sync) and self.time < other.time def __eq__(self, other): - return isinstance(other, Schedule) and self.order == other.order + return isinstance(other, Schedule) and self.order == other.order and self.sync == other.sync def __hash__(self): - return hash(tuple(self.order)) + return hash((tuple(self.order), tuple(tuple(pair) for pair in self.sync))) def __repr__(self) -> str: nodes_str = [repr(node) for node in self.order] + sync_str = ["(" + repr(n0) + ", " + repr(n1) + ")" for n0, n1 in self.sync] # s = f"len={len(nodes_str)}, " s = "" s += ", ".join(nodes_str) s += "\n" + s += ",".join(sync_str) + s += "\n" s += f"Time:{self.time}, 
" s += f"tensor_core:{self.hw_time['tensor_core']}, " s += f"cuda_core:{self.hw_time['cuda_core']}, " @@ -38,11 +43,20 @@ def __repr__(self) -> str: dp_dict = {} topk = 10 i = 0 +smem_interval = 1024 +reg_interval = 1024 +smem_max_cap = 4 * 1024 +reg_max_cap = 4 * 1024 issue_latency = 0.001 -def get_schedule(nodes: List[Node], in_node: Node, in_node_dsts: List[Node], graph: Graph) -> List[Schedule]: +def get_schedule(nodes: List[Node], in_node: Node, graph: nx.DiGraph, mem_cap: List[float]) -> List[Schedule]: schedule_list = [] - for schedule in dp_dict[frozenset(nodes)]: + smem_cap, reg_cap = mem_cap + schedules = [] + for s in range(int(smem_cap / smem_interval)): + for r in range(int(reg_cap / reg_interval)): + schedules.extend(dp_dict[frozenset(nodes)][s][r]) + for schedule in schedules: new_order = [in_node] + schedule.order new_hw_time = {"tensor_core":0, "cuda_core":0, "tma":0} # node: (start_time, end_time) @@ -52,22 +66,91 @@ def get_schedule(nodes: List[Node], in_node: Node, in_node_dsts: List[Node], gra start_time = max(new_hw_time[cur_node.hw_usage], last_start_time + issue_latency) for pred_node in new_order: # Need to issue after previous sync node done - if pred_node in issue_dict and pred_node._sync_type == "sync": + if pred_node in issue_dict and pred_node.sync_type == "sync": start_time = max(start_time, issue_dict[pred_node][1] + issue_latency) if graph.has_edge(pred_node, cur_node): edge = graph.get_edge_data(pred_node, cur_node) if edge['type'] == "data": - assert pred_node in issue_dict, "error: pred_node not in issue_dict" + assert pred_node in issue_dict, "Error: pred_node not in issue_dict" # Data denpendency start_time = max(start_time, issue_dict[pred_node][1] + issue_latency) issue_dict[cur_node] = [start_time, start_time + cur_node.time] new_hw_time[cur_node.hw_usage] = start_time + cur_node.time last_start_time = start_time - schedule_list.append(Schedule(new_order, new_hw_time)) + + new_schedule = Schedule(new_order, schedule.sync, new_hw_time) + smem_fp, reg_fp = calculate_memory_footprint(new_schedule, graph) + + if smem_fp <= smem_cap and reg_fp <= reg_cap: + schedule_list.append(new_schedule) + + # Add syncs conditions for async nodes + if in_node.sync_type == "async": + for sync_dst_node in schedule.order: + if sync_dst_node.name == "Output": + continue + # If no hardware resource conflict, no need to add sync + if sync_dst_node.hw_usage != in_node.hw_usage and sync_dst_node.output_buffer not in in_node.input_buffer: + continue + new_schedule = Schedule(new_order, schedule.sync + [(in_node, sync_dst_node)], new_hw_time) + smem_fp, reg_fp = calculate_memory_footprint(new_schedule, graph) + if smem_fp <= smem_cap and reg_fp <= reg_cap: + schedule_list.append(new_schedule) return sorted(schedule_list) +@lru_cache(None) +def get_data_dependency_inputs(node: Node, graph: nx.DiGraph) -> List[Node]: + return [pred for pred in graph.predecessors(node) if graph.get_edge_data(pred, node).get('type') == 'data'] + +@lru_cache(None) +def get_data_dependency_outputs(node: Node, graph: nx.DiGraph) -> List[Node]: + return [succ for succ in graph.successors(node) if graph.get_edge_data(node, succ).get('type') == 'data'] -def dp(graph: Graph, prev_sub_graphs: List[List[Node]]): +def calculate_memory_footprint(schedule: Schedule, graph: nx.DiGraph) -> List[float]: + smem_fp = 0 + reg_fp = 0 + smem_usage = 0 + reg_usage = 0 + + def allocate_outputs(node: Node): + nonlocal smem_usage, reg_usage + if node.output_buffer == "smem": + smem_usage += node.dsize + elif 
node.output_buffer == "register": + reg_usage += node.dsize + + def free_outputs(node: Node): + nonlocal smem_usage, reg_usage + if node.output_buffer == "smem": + smem_usage -= node.dsize + elif node.output_buffer == "register": + reg_usage -= node.dsize + + ref_dict = {node: [] for node in schedule.order} + for ordered_node in schedule.order: + for pred in get_data_dependency_inputs(ordered_node, graph): + if pred in schedule.order: + ref_dict[pred].append(ordered_node) + for ordered_node in schedule.order: + sync_srcs = [src for src, dst in schedule.sync if dst == ordered_node] + for pred in list(set(get_data_dependency_inputs(ordered_node, graph) + sync_srcs)): + if pred in schedule.order: + for pred_src in get_data_dependency_inputs(pred, graph): + if pred_src in schedule.order: + if len(ref_dict[pred_src]) == 0: + continue + ref_dict[pred_src].remove(pred) + if len(ref_dict[pred_src]) == 0: + free_outputs(pred_src) + allocate_outputs(ordered_node) + if smem_usage > smem_fp: + smem_fp = smem_usage + if reg_usage > reg_fp: + reg_fp = reg_usage + # print(smem_fp, reg_fp) + return [smem_fp, reg_fp] + +def dp(graph: nx.DiGraph, prev_sub_graphs: List[List[Node]]): global dp_dict global i # set(sub_graph_nodes): [in_node_0, in_node_1, ...] @@ -88,23 +171,31 @@ def dp(graph: Graph, prev_sub_graphs: List[List[Node]]): sub_graphs_dict[key] = [] sub_graphs_dict[key].append((pred, prev_nodes)) for new_nodes_set, nodes in sub_graphs_dict.items(): - schedule_list = [] - for in_node, prev_nodes in nodes: - schedule_list.extend(get_schedule(prev_nodes, in_node, [], graph)) - dp_dict[new_nodes_set] = sorted(set(schedule_list))[:topk] + dp_dict[new_nodes_set] = [] + for s in range(smem_max_cap // smem_interval): + smem_cap = smem_interval * (s + 1) + dp_dict[new_nodes_set].append([]) + for r in range(reg_max_cap // reg_interval): + reg_cap = reg_interval * (r + 1) + schedule_list = [] + for in_node, prev_nodes in nodes: + schedule_list.extend(get_schedule(prev_nodes, in_node, graph, [smem_cap, reg_cap])) + dp_dict[new_nodes_set][s].append(sorted(set(schedule_list))[:topk]) + sub_graphs = list(list(n) for n in list(sub_graphs_dict)) if len(sub_graphs) > 0: dp(graph, sub_graphs) -def duplicate_with_stream(graph: Graph, stream: int) -> Tuple[Graph, Node]: +def duplicate_with_stream(graph: nx.DiGraph, stream: int) -> Tuple[nx.DiGraph, Node]: node_num = len(graph.nodes) node_list = [] duplicate_graph = nx.DiGraph() for i in range(stream): - mapping = {node: Node(node.id, [], node.name+f'_{i}', node._sync_type, node.hw_usage, node.time, node.buffer_layer) for node in graph.nodes} + mapping = {node: Node(node.id, node.name+f'_{i}', node.sync_type, node.hw_usage, node.time, node.input_buffer, node.output_buffer) for node in graph.nodes} g = nx.relabel_nodes(graph, mapping) - node_list.extend(g.nodes) + ordered_ndoes = sorted(g.nodes) + node_list.extend(ordered_ndoes) duplicate_graph = nx.compose(duplicate_graph, g) if i > 0: @@ -112,7 +203,7 @@ def duplicate_with_stream(graph: Graph, stream: int) -> Tuple[Graph, Node]: duplicate_graph.add_edge(node_list[(i - 1) * node_num + nid], node_list[i * node_num + nid], type="control.issue") out_nodes = [node for node in duplicate_graph.nodes if duplicate_graph.out_degree(node) == 0] - output_node = Node(node_num * stream, [], "Output", "sync", "cuda_core", 0, []) + output_node = Node(node_num * stream, "Output", "sync", "cuda_core", 0, ["register"], "register") for out_node in out_nodes: duplicate_graph.add_edge(out_node, output_node, type="data") @@ -120,14 +211,18 
@@ def duplicate_with_stream(graph: Graph, stream: int) -> Tuple[Graph, Node]: if __name__ == "__main__": stream = 3 - _, graph = load_model("inputs.json") + graph = load_model("inputs.json") duplicate_graph, output_node = duplicate_with_stream(graph, stream) print(duplicate_graph) print(duplicate_graph.edges(data=True)) # for node in graph.nodes: # print(node) - dp_dict[frozenset([output_node])] = [Schedule([output_node], {"tensor_core":0, "cuda_core":0, "tma":0})] + dp_dict[frozenset([output_node])] = [] + for s in range(smem_max_cap // smem_interval): + dp_dict[frozenset([output_node])].append([]) + for r in range(reg_max_cap // reg_interval): + dp_dict[frozenset([output_node])][s].append([Schedule([output_node], [], {"tensor_core":0, "cuda_core":0, "tma":0})]) dp(duplicate_graph, [[output_node]]) # for k, v in dp_dict.items(): diff --git a/tl_pipeline/graph.py b/tl_pipeline/graph.py index f08fe75e92f5..925b82f08e3b 100644 --- a/tl_pipeline/graph.py +++ b/tl_pipeline/graph.py @@ -1,217 +1,46 @@ from typing import Any, List, Dict, Optional, Tuple, Union -import heapq - -class Edge: - def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, w: int): - self.w = w - self.src_node = src_node - self.dst_node = dst_node - self.src_id = src_id - self.dst_id = dst_id - - def __repr__(self) -> str: - return "" + self.dst_node.name + ">" - -class DataDependencyEdge(Edge): - def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, w: int): - super().__init__(src_node, dst_node, src_id, dst_id, w) - - def __repr__(self) -> str: - return "" + self.dst_node.name + ">" - -class CondDependencyEdge(Edge): - def __init__(self, src_node: 'Node', dst_node: 'Node', src_id: int, dst_id: int, w: int): - super().__init__(src_node, dst_node, src_id, dst_id, w) - - def __repr__(self) -> str: - return "" + self.dst_node.name + ">" +from math import prod class Node: - def __init__(self, node_id: int, inputs: List[Union[Tuple['Node', int], 'Node', None]], name: str, sync_type: str, hw_usage: List[str], time: float, buffer_layer: List[str]): + def __init__( + self, + node_id: int, + name: str, + sync_type: str, + hw_usage: str, + time: float, + input_buffer: List[str], + output_buffer: str + ): self.id = node_id self.name = name - self._out_edges = [] - self._in_edges = [] - self._shapes = [] - self._dtypes = [] - self._tag = {} - self._sync_type = sync_type + # self._out_edges = [] + # self._in_edges = [] + # self._shapes = [] + # self._dtypes = [] + # self._tag = {} + self.shape = [16, 16] + self.dtype = "float16" + self.sync_type = sync_type self.hw_usage = hw_usage self.time = time - self.buffer_layer = buffer_layer + self.input_buffer = input_buffer + self.output_buffer = output_buffer assert sync_type in ['sync', 'async'], 'invaild sync_type!' 
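+        # Constructor sketch with illustrative values (the node name is made up);
+        # hw_usage is a single unit name and buffers are "smem" or "register":
+        #   Node(2, "gemm_qk", "async", "tensor_core", 2.0, ["smem"], "register")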
- for i, node in enumerate(inputs): - if node is None: - inputs[i] = PlaceHolderNode() - - for dst_id, n in enumerate(inputs): - if isinstance(n, Node): - n = (n, 0) - assert(len(n) == 2) - src_node, src_id = n[0], n[1] - edge = DataDependencyEdge(src_node, self, src_id, dst_id, 0) - self._in_edges.append(edge) - src_node._out_edges.append(edge) - @property - def inputs(self) -> List[Edge]: - return self._in_edges - - @property - def outputs(self) -> List[Edge]: - return self._out_edges - - def set_inputs(self, i: int, edge: Edge): - assert i < len(self._in_edges) - self._in_edges[i] = edge - - def set_outputs(self, i: int, edge: Edge): - assert i < len(self._out_edges) - self._out_edges[i] = edge - - def get_shape(self, id: int = 0) -> List[int]: - return self._shapes[id] - - def set_shape(self, shape: List[int], id=0, overwrite=False) -> None: - if len(self._shapes) <= id: - self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)]) - elif self._shapes[id] is not None and not overwrite: - assert self._shapes[id] == list(map(int, shape)), (self._shapes, list(map(int, shape))) - self._shapes[id] = list(map(int, shape)) - - def is_placeholder(self): - return False - - def is_output(self): - return False - - def add_tag(self, k: str, v: Any = True) -> None: - self._tag[k] = v - - def get_tag(self, k: str) -> Any: - if k not in self._tag: - return None - return self._tag[k] - - def num_outputs(self) -> int: - if len(self.outputs) == 0: - return 0 - return max([e.src_id for e in self.outputs]) + 1 - - def get_ir(self) -> str: - raise NotImplementedError() + def dsize(self) -> int: + assert self.dtype in ['float16', 'float32'], 'invaild dtype!' + dtype_size = {'float16': 2, 'float32': 4} + return dtype_size[self.dtype] * prod(self.shape) + def __lt__(self, other: 'Node') -> bool: + return self.id < other.id + def __repr__(self) -> str: # return "" return self.name -class PlaceHolderNode(Node): - def __init__(self, name=""): - super().__init__([], "PlaceHolder " + name, "sync") - - def is_placeholder(self): - return True - - def get_ir(self) -> str: - return "placeholder" - -class OutputNode(Node): - def __init__(self, node, id=0): - super().__init__([(node, id)], "Output ", "sync") - # self.set_shape(node.get_shape(id)) - # self.set_dtype(node.get_dtype(id)) - - def is_output(self): - return True - - def get_ir(self) -> str: - return "output" - -class Graph: - def __init__(self, nodes: List['Node']): - self.nodes = nodes - self.node_count = len(self.nodes) - self.dist = [[float('inf')] * self.node_count for _ in range(self.node_count)] - self.init_dist() - self.edge_count = 0 - for node in self.nodes: - self.edge_count += len(node.outputs) - - print("dist:") - print(self.dist) - print("edge_count:") - print(self.edge_count) - - def init_dist(self): - for i in range(self.node_count): - self.dist[i][i] = 0 - for node in self.nodes: - for edge in node.outputs: - self.add_edge(node, edge.dst_node, edge.w) - - def exist_edge(self, n0: 'Node', n1: 'Node') -> bool: - assert n0 in self.nodes, "n0 not in graph!" - assert n1 in self.nodes, "n1 not in graph!" 
- for edge in n0.outputs: - if n1 == edge.dst_node: - return True - return False - - def update_dist(self, u, v, w): - dist = self.dist.copy() - if dist[u][v] > w: - dist[u][v] = w - - for i in range(self.node_count): - for j in range(self.node_count): - if dist[i][v] > dist[i][u] + w: - dist[i][v] = dist[i][u] + w - if dist[u][j] > w + dist[v][j]: - dist[u][j] = w + dist[v][j] - if dist[i][j] > dist[i][u] + w + dist[v][j]: - dist[i][j] = dist[i][u] + w + dist[v][j] - - # Detect negetive ring - for i in range(self.node_count): - for j in range(self.node_count): - if dist[i][v] > dist[i][u] + w: - return None - if dist[u][j] > w + dist[v][j]: - return None - if dist[i][j] > dist[i][u] + w + dist[v][j]: - return None - return dist - - - def add_edge(self, n0, n1, w) -> bool: - self.dist[n0.id][n1.id] = w - dist = self.update_dist(n0.id, n1.id, w) - if dist is None: - print("Invalid, negetive ring detected.") - return False # Invalid - self.dist = dist - print("updated dist:") - print(self.dist) - return True - - def del_edge(self, n0, n1): - pass - - def check_legality(self) -> bool: - return True - -def print_graph(nodes: List['Node']): - for node in nodes: - print(node) - for edge in node.inputs: - print(edge) - print('-'*100) - -def get_path_value(n0: 'Node', n1: 'Node') -> int: - - return - import networkx as nx from graph import * @@ -222,8 +51,7 @@ def load_model(fname: str) -> List[Node]: with open(fname) as f: a = json.load(f) node_map = {item[0] : None for item in a} - ordered_nodes = [] - for node_id, name, sync_type, hw_usage, time, buffer_layer, is_output, inputs in a: + for node_id, name, sync_type, hw_usage, input_buffer, output_buffer, inputs, time in a: input_list = [] for src_node, src_id in inputs: if src_node not in node_map: @@ -231,10 +59,9 @@ def load_model(fname: str) -> List[Node]: else: assert node_map[src_node] is not None, "Detected ring in topo order {}->{} !".format(src_node, node_id) input_list.append([node_map[src_node], src_id]) - node = Node(node_id, input_list, name, sync_type, hw_usage, time, buffer_layer) + node = Node(node_id, name, sync_type, hw_usage, time, input_buffer, output_buffer) for src_node, _ in inputs: assert node_map[src_node] is not None graph.add_edge(node_map[src_node], node, weight=0, type="data") node_map[node_id] = node - ordered_nodes.append(node) - return ordered_nodes, graph \ No newline at end of file + return graph \ No newline at end of file From af12dee73e65e5d52a6b29af9c3371d549060156 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Sun, 1 Sep 2024 09:13:35 +0000 Subject: [PATCH 11/23] [tl] Locate bug in FA, init microbenchmarks. 
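This patch (1) fixes the B_type typo in gemm_sm90.h, where B_type was derived
from A_type_raw; (2) adds tl_scripts/mha_test.py, whose header comment
documents a wrong result at dim=64 that appears to come from acc_s_cast
reusing Q_local's registers; (3) replaces tl_scripts/test.py with
torch_ref.py, a pure-PyTorch reference for the attention math; and (4) adds
the standalone tl_verify/ microbenchmark, a torch extension plus compile
script driving hand-written FA kernels with and without TMA.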
--- src/tl/tl_templates/gemm_sm90.h | 18 +- tl_scripts/mha_test.py | 191 +++++++++++++++ tl_scripts/test.py | 54 ----- tl_scripts/torch_ref.py | 132 +++++++++++ tl_verify/compile.py | 46 ++++ tl_verify/cuda_interface.cpp | 34 +++ tl_verify/fa_kernel.cu | 408 ++++++++++++++++++++++++++++++++ tl_verify/fa_kernel.hpp | 25 ++ tl_verify/fa_no_tma.cu | 199 ++++++++++++++++ tl_verify/main.py | 69 ++++++ tl_verify/setup.py | 49 ++++ 11 files changed, 1170 insertions(+), 55 deletions(-) create mode 100644 tl_scripts/mha_test.py delete mode 100644 tl_scripts/test.py create mode 100644 tl_scripts/torch_ref.py create mode 100644 tl_verify/compile.py create mode 100644 tl_verify/cuda_interface.cpp create mode 100644 tl_verify/fa_kernel.cu create mode 100644 tl_verify/fa_kernel.hpp create mode 100644 tl_verify/fa_no_tma.cu create mode 100644 tl_verify/main.py create mode 100644 tl_verify/setup.py diff --git a/src/tl/tl_templates/gemm_sm90.h b/src/tl/tl_templates/gemm_sm90.h index 30ee86808100..372a947579c7 100644 --- a/src/tl/tl_templates/gemm_sm90.h +++ b/src/tl/tl_templates/gemm_sm90.h @@ -50,7 +50,7 @@ template ::value, tfloat32_t, A_type_raw>; - using B_type = conditional_t::value, tfloat32_t, A_type_raw>; + using B_type = conditional_t::value, tfloat32_t, B_type_raw>; using C_type = C_type_raw; static constexpr GMMA::Major GmmaMajorA = trans_A ? GMMA::Major::MN : GMMA::Major::K; @@ -115,6 +115,22 @@ class GemmTensorOp { partition_shape_A(tiled_mma, Shape, Int>{})); Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast(pC)), partition_shape_C(tiled_mma, Shape, Int>{})); + + // warpgroup_fence_operand(tCrA); + // warpgroup_fence_operand(acc); + // for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // warpgroup_arrive(); + // // (V,M) x (V,N) => (V,M,N) + // gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc); + // if(k_block == 0) { + // tiled_mma.accumulate_ = GMMA::ScaleOut::One; + // } + // warpgroup_commit_batch(); + // } + // warpgroup_wait<0>(); + // warpgroup_fence_operand(acc); + // warpgroup_fence_operand(tCrA); + warpgroup_fence_operand(acc); warpgroup_arrive(); diff --git a/tl_scripts/mha_test.py b/tl_scripts/mha_test.py new file mode 100644 index 000000000000..e52733ad3be7 --- /dev/null +++ b/tl_scripts/mha_test.py @@ -0,0 +1,191 @@ +import torch +from tvm import tl +import tvm.tl.language as T +from functools import partial + +# This script gives a wrong result when dim=64. +# The error is due to the acc_s_cast tensor reuse the register of Q_local tensor (don't know why). +# It is a strange error because in PTX file, the register of Q_local and acc_s_cast are different. 
+# To reproduce the error, you can try the following script: +# with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): +# Q_shared = T.alloc_shared([block_M, dim], dtype) +# Q_local = T.alloc_fragment([block_M, dim], dtype) +# K_shared = T.alloc_shared([block_N, dim], dtype) +# V_shared = T.alloc_shared([block_N, dim], dtype) +# acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) +# acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) +# acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + +# T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) +# T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) +# T.fill(acc_o, 0) +# T.copy(Q_shared, Q_local) +# for i, j in T.Parallel(block_M, dim): +# Q_local[i, j] *= scale +# loop_range = ( +# T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) +# ) +# for k in T.Pipelined(loop_range, num_stages=1): +# T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) +# if is_casual: +# for i, j in T.Parallel(block_M, block_N): +# acc_s[i, j] = T.if_then_else( +# bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) +# ) +# else: +# T.clear(acc_s) +# T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) +# T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) +# for i, j in T.Parallel(block_M, block_N): +# acc_s[i, j] = T.exp2(acc_s[i, j] - 32) +# T.copy(acc_s, acc_s_cast) +# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) +# T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + +# To fix this, we can either use T.gemm(Q_shared, K_shared, acc_s), like in FlashAttention implementation, +# or use different wgmma instrutcion (like M64N32K16) + +def flashattn(batch, heads, seq_len, dim, is_casual, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + Output: T.Buffer(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.copy(Q_shared, Q_local) + for i, j in T.Parallel(block_M, dim): + Q_local[i, j] *= scale + loop_range = ( + T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) + ) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_casual: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = 
T.if_then_else( + bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) + ) + else: + T.clear(acc_s) + T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, casual): + # from flash_attn.flash_attn_interface import flash_attn_func + + # return flash_attn_func(Q, K, V, causal=casual) + assert casual == False, "casual is not supported" + batch, seq_len, heads, dim = Q.size() + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + block_M = seq_len + block_N = 64 if dim <= 128 else 32 + acc_s = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, block_M, heads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float('-inf')) + Q_scaled = Q * scale + + for i in range(int(seq_len / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_scaled, K[:, i * block_N : (i + 1) * block_N, :, :]) # [batch, seqlen, heads, block_N] + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] + scores_scale = torch.exp2(scores_max_prev - scores_max) + acc_o *= scores_scale[:, :, :, None].transpose(1, 2) + acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) + acc_s_cast = acc_s.to(torch.float16) + acc_o += torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, i * block_N : (i + 1) * block_N, :, :]) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o /= logsum[:, :, :, None].transpose(1, 2) + return acc_o.to(torch.float16) + +# def ref_program(Q, K, V, casual): +# dim = Q.size(-1) +# scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + +# # Step 2: Scale the scores by the square root of dim +# scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + +# # Step 3: Apply softmax to get the attention weights +# attention_weights = F.softmax(scores, dim=-1) + +# # Step 4: Multiply the attention weights by the values (V) +# # This gives us the final output of shape [batch, seq_len, heads, dim] +# output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + +# return output + +if 
__name__ == "__main__": + BATCH, H, N_CTX, D_HEAD = 1, 1, 64, 64 + casual = False + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if casual: + total_flops *= 0.5 + BLOCK_M = 64 + BLOCK_N = 64 if D_HEAD <= 128 else 32 + program = flashattn(BATCH, H, N_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) + ref_program = partial(ref_program, casual=casual) + mod, params = tl.lower(program) + mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) + + # latency = mod.do_bench(ref_program, warmup=500) + # print("{:.2f} ms".format(latency)) + # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + # latency = mod.do_bench(mod) + # print("{:.2f} ms".format(latency)) + # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file diff --git a/tl_scripts/test.py b/tl_scripts/test.py deleted file mode 100644 index d71a20d8c01d..000000000000 --- a/tl_scripts/test.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import random -import numpy as np - -def set_seed(seed): - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. - np.random.seed(seed) # Numpy module. - random.seed(seed) # Python random module. - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False -set_seed(42) - -SEQLEN = 1024 -D = 64 -blockM = SEQLEN -blockN = 64 -Q = torch.randn((SEQLEN, D), dtype=torch.float16) -K = torch.randn((SEQLEN, D), dtype=torch.float16) -V = torch.randn((SEQLEN, D), dtype=torch.float16) - -def ref_program(Q, K, V): - qk = torch.matmul(Q, K.transpose(-1, -2)) - m = qk.max(dim=-1, keepdim=True).values - p = torch.exp(qk - m) - s = p / p.sum(dim=-1, keepdim=True) - o = torch.matmul(s, V) - return o - -def test_program(Q, K, V): - lse = torch.randn((blockM), dtype=float) - m = torch.randn((blockM), dtype=float) - m_new = torch.randn((blockM), dtype=float) - acc_o = torch.randn((blockM, D), dtype=float) - m.fill_(float('-inf')) - m_new.fill_(float('-inf')) - lse.fill_(float('-inf')) - acc_o.fill_(float(0)) - for i in range(int(SEQLEN / blockN)): - qk = torch.matmul(Q, (K[i * blockN : (i + 1) * blockN, :]).transpose(-1, -2)) # [blockM, blockN] - m_new = torch.max(qk.max(dim=-1, keepdim=False).values, m_new) # [blockM] - p = torch.exp(qk - m_new.unsqueeze(dim=1)) # [blockM, blockN] - lse = m_new + torch.log(torch.exp(lse - m_new) + p.sum(dim=-1, keepdim=False)) # [blockM] - acc_o = acc_o * torch.exp(m - m_new).unsqueeze(1) - m = m_new - acc_o += torch.matmul(p.to(torch.float16), V[i * blockN : (i + 1) * blockN, :]) - acc_o = acc_o * torch.exp(m_new - lse).unsqueeze(1) - return acc_o.to(torch.float16) - -ref_output = ref_program(Q, K, V) -test_output = test_program(Q, K, V) -are_close = torch.allclose(ref_output, test_output, rtol=1e-03, atol=1e-03) -print(f"Are the outputs close? {are_close}") \ No newline at end of file diff --git a/tl_scripts/torch_ref.py b/tl_scripts/torch_ref.py new file mode 100644 index 000000000000..704a934d0be9 --- /dev/null +++ b/tl_scripts/torch_ref.py @@ -0,0 +1,132 @@ +import torch +import random +import numpy as np + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + np.random.seed(seed) # Numpy module. + random.seed(seed) # Python random module. 
+ torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False +set_seed(42) + +# SEQLEN = 1024 +# D = 64 +# blockM = SEQLEN +# blockN = 64 +# Q = torch.randn((SEQLEN, D), dtype=torch.float16) +# K = torch.randn((SEQLEN, D), dtype=torch.float16) +# V = torch.randn((SEQLEN, D), dtype=torch.float16) + +# def ref_program(Q, K, V): +# qk = torch.matmul(Q, K.transpose(-1, -2)) +# m = qk.max(dim=-1, keepdim=True).values +# p = torch.exp(qk - m) +# s = p / p.sum(dim=-1, keepdim=True) +# o = torch.matmul(s, V) +# return o + +# def test_program(Q, K, V): +# lse = torch.randn((blockM), dtype=float) +# m = torch.randn((blockM), dtype=float) +# m_new = torch.randn((blockM), dtype=float) +# acc_o = torch.randn((blockM, D), dtype=float) +# m.fill_(float('-inf')) +# m_new.fill_(float('-inf')) +# lse.fill_(float('-inf')) +# acc_o.fill_(float(0)) +# for i in range(int(SEQLEN / blockN)): +# qk = torch.matmul(Q, (K[i * blockN : (i + 1) * blockN, :]).transpose(-1, -2)) # [blockM, blockN] +# m_new = torch.max(qk.max(dim=-1, keepdim=False).values, m_new) # [blockM] +# p = torch.exp(qk - m_new.unsqueeze(dim=1)) # [blockM, blockN] +# lse = m_new + torch.log(torch.exp(lse - m_new) + p.sum(dim=-1, keepdim=False)) # [blockM] +# acc_o = acc_o * torch.exp(m - m_new).unsqueeze(1) +# m = m_new +# acc_o += torch.matmul(p.to(torch.float16), V[i * blockN : (i + 1) * blockN, :]) +# acc_o = acc_o * torch.exp(m_new - lse).unsqueeze(1) +# return acc_o.to(torch.float16) + +# ref_output = ref_program(Q, K, V) +# test_output = test_program(Q, K, V) +# are_close = torch.allclose(ref_output, test_output, rtol=1e-03, atol=1e-03) +# print(f"Are the outputs close? {are_close}") + +import torch.nn.functional as F + +batch = 1 +seq_len = 1024 +heads = 1 +dim = 64 +shape = [batch, seq_len, heads, dim] +Q = torch.randn(shape, device="cuda", dtype=torch.float16) +K = torch.randn(shape, device="cuda", dtype=torch.float16) +V = torch.randn(shape, device="cuda", dtype=torch.float16) +# Q = torch.ones(shape, device="cuda", dtype=torch.float16) +# K = torch.ones(shape, device="cuda", dtype=torch.float16) +# V = torch.ones(shape, device="cuda", dtype=torch.float16) + +def test_program(Q, K, V): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + block_M = seq_len + block_N = 64 + acc_s = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, block_M, heads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float('-inf')) + Q *= scale + + for i in range(int(seq_len / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum('bqhd,bkhd->bhqk', Q, K[:, i * block_N : (i + 1) * block_N, :, :]) # [batch, seqlen, heads, block_N] + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] + scores_scale = torch.exp2(scores_max_prev - scores_max) + acc_o *= scores_scale[:, :, :, None].transpose(1, 2) + acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) + # print("acc_s:", acc_s) + 
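+        # Online-softmax bookkeeping: the scores were pre-scaled by
+        # 1/sqrt(dim) * log2(e), so exp2 here equals exp of the usual scaled
+        # logits. When the running row max grows, acc_o is rescaled by
+        # exp2(prev_max - new_max); the probabilities are cast to fp16 only for
+        # the P*V matmul, and logsum accumulates the normalizer applied after
+        # the loop.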
acc_s_cast = acc_s.to(torch.float16) + acc_o += torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, i * block_N : (i + 1) * block_N, :, :]) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + # print("acc_o:", acc_o.size()) + # print("logsum:", logsum.size()) + acc_o /= logsum[:, :, :, None].transpose(1, 2) + return acc_o.to(torch.float16) + + +def ref_program(Q, K, V): + dim = Q.size(-1) + scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + + + # Step 2: Scale the scores by the square root of dim + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + + + # Step 3: Apply softmax to get the attention weights + attention_weights = F.softmax(scores, dim=-1) + + # print("scores:", attention_weights) + # Step 4: Multiply the attention weights by the values (V) + # This gives us the final output of shape [batch, seq_len, heads, dim] + output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) + + + return output + +ref_output = ref_program(Q, K, V) +test_output = test_program(Q, K, V) +are_close = torch.allclose(ref_output, test_output, rtol=1e-03, atol=1e-03) +print(f"Are the outputs close? {are_close}") + +print("ref_output:", ref_output) +print("test_output:", test_output) \ No newline at end of file diff --git a/tl_verify/compile.py b/tl_verify/compile.py new file mode 100644 index 000000000000..6504e61656bc --- /dev/null +++ b/tl_verify/compile.py @@ -0,0 +1,46 @@ +import os +import os.path as osp +# from tvm.contrib import nvcc +import subprocess + +with open("gemmx1.cu", "r") as f: + code = f.read() + +tvm_root = osp.join(osp.dirname(__file__), "../..") +tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) +if "TL_CUTLASS_PATH" in os.environ: + cutlass_path = os.environ["TL_CUTLASS_PATH"] +else: + cutlass_path = osp.abspath(osp.join(tvm_root, "3rdparty/cutlass/include")) + + +format = "ptx" +arch = f"sm_90a" + +# print(tl_template_path) +# print(cutlass_path) + +nvcc_command = [ + "nvcc", + "-o", "gemmx1", + "-arch=" + arch, + "--use_fast_math", + "-std=c++17", + "-I" + tl_template_path, + "-I" + cutlass_path, + "-lcuda", + "gemmx1.cu" +] + +subprocess.run(nvcc_command, check=True) + +# nvcc -ptx fa_kernel.cu -o fa_kernel.ptx -O3 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -arch=sm_90a --use_fast_math -std=c++17 -I/home/msra/cy/tvm.tl/src/tl -I/home/msra/cy/tvm.tl/cutlass/include -lcuda +"-O3", +"-U__CUDA_NO_HALF_OPERATORS__", +"-U__CUDA_NO_HALF_CONVERSIONS__", +"-U__CUDA_NO_BFLOAT16_OPERATORS__", +"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", +"-U__CUDA_NO_BFLOAT162_OPERATORS__", +"-U__CUDA_NO_BFLOAT162_CONVERSIONS__", +"--expt-relaxed-constexpr", +"--expt-extended-lambda", \ No newline at end of file diff --git a/tl_verify/cuda_interface.cpp b/tl_verify/cuda_interface.cpp new file mode 100644 index 000000000000..1d2b7fcce7ec --- /dev/null +++ b/tl_verify/cuda_interface.cpp @@ -0,0 +1,34 @@ +#include +#include +#include +#include +#include +#include "fa_kernel.hpp" + +void main_kernel_launcher(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output); +void main_kernel_launcher_no_tma(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output); + +at::Tensor kernel_function(at::Tensor Q, at::Tensor K, at::Tensor V) { + at::Tensor output = torch::empty_like(Q); + main_kernel_launcher(Q, K, V, output); + return output; +} + 
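+// Hypothetical usage from Python once the extension is built by tl_verify/setup.py
+// (the actual module name comes from setup.py; "fa_ext" is a placeholder):
+//   import torch, fa_ext
+//   Q = torch.randn(1, 256, 1, 64, device="cuda", dtype=torch.float16)
+//   K, V = torch.randn_like(Q), torch.randn_like(Q)
+//   out_tma    = fa_ext.kernel_function(Q, K, V)
+//   out_no_tma = fa_ext.kernel_function_no_tma(Q, K, V)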
+at::Tensor kernel_function_no_tma(at::Tensor Q, at::Tensor K, at::Tensor V) { + at::Tensor output = torch::empty_like(Q); + main_kernel_launcher_no_tma(Q, K, V, output); + return output; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("kernel_function", &kernel_function, "FA Kernel Function"); + m.def("kernel_function_no_tma", &kernel_function_no_tma, "FA Kernel Launcher"); +} + +void main_kernel_launcher(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output) { + host_function(Flash_fwd_params{Q.data_ptr(), K.data_ptr(), V.data_ptr(), output.data_ptr(), Q.size(0), Q.size(1), Q.size(2), Q.size(3), 64, 64}); +} + +void main_kernel_launcher_no_tma(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output) { + host_function_no_tma(Flash_fwd_params{Q.data_ptr(), K.data_ptr(), V.data_ptr(), output.data_ptr(), Q.size(0), Q.size(1), Q.size(2), Q.size(3), 64, 64}); +} \ No newline at end of file diff --git a/tl_verify/fa_kernel.cu b/tl_verify/fa_kernel.cu new file mode 100644 index 000000000000..4709c03f92fa --- /dev/null +++ b/tl_verify/fa_kernel.cu @@ -0,0 +1,408 @@ +#include +#include +#include +#include +#include +#include +#include "fa_kernel.hpp" + +extern "C" __global__ void __launch_bounds__(128) main_kernel(__grid_constant__ const CUtensorMap K_desc, half_t* __restrict__ Output, __grid_constant__ const CUtensorMap Q_desc, __grid_constant__ const CUtensorMap V_desc) { +} + +template +static std::string ArrayToStr(const T* ptr, size_t n) { + std::stringstream ss; + ss << "["; + for (size_t i = 0; i < n; i++) { + if (i > 0) ss << ", "; + ss << ptr[i]; + } + ss << "]"; + return ss.str(); +} + +struct TensorMapArgs { + CUtensorMap* map; + CUtensorMapDataType type; + cuuint32_t tensorRank; + void* globalAddress; + cuuint64_t globalDim[5], globalStride[5]; + cuuint32_t boxDim[5], elementStrides[5]; + CUtensorMapInterleave interleave; + CUtensorMapSwizzle swizzle; + CUtensorMapL2promotion l2Promotion; + CUtensorMapFloatOOBfill oobFill; + + std::string ToDebugString() { + std::stringstream ss; + ss << "TMA Desc Addr: " << map << std::endl + << "format " << type << std::endl + << "dim " << tensorRank << std::endl + << "gmem_address " << globalAddress << std::endl + << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl + << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl + << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl + << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl + << "interleave " << interleave << std::endl + << "swizzle " << swizzle << std::endl + << "l2Promotion " << l2Promotion << std::endl + << "oobFill " << oobFill << std::endl; + return ss.str(); + } +}; + +void host_function(Flash_fwd_params params) { + int num_m_blocks = (params.seq_len + params.block_M - 1) / params.block_M; + dim3 grid(num_m_blocks, params.head, params.batch); + dim3 block(128); + size_t sharedMemSize = (params.block_M + 2 * params.block_N) * params.dim * sizeof(half_t); // 24576; + + // int size = params.batch * params.head * params.seq_len * params.dim * sizeof(half_t); + + CUtensorMap Q_desc = {0}; + CUtensorMap K_desc = {0}; + CUtensorMap V_desc = {0}; + TensorMapArgs Q_arg; + TensorMapArgs K_arg; + TensorMapArgs V_arg; + + Q_arg.map = &Q_desc; + Q_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + Q_arg.tensorRank = 4; + Q_arg.globalAddress = params.q_ptr; + Q_arg.globalDim[0] = static_cast(params.dim); + Q_arg.globalDim[1] = static_cast(params.head); + Q_arg.globalDim[2] = static_cast(params.seq_len); + Q_arg.globalDim[3] = 
static_cast(params.batch); + Q_arg.globalStride[0] = static_cast(2); + Q_arg.globalStride[1] = static_cast(128); + Q_arg.globalStride[2] = static_cast(128); + Q_arg.globalStride[3] = static_cast(32768); + Q_arg.boxDim[0] = static_cast(64); + Q_arg.boxDim[1] = static_cast(1); + Q_arg.boxDim[2] = static_cast(64); + Q_arg.boxDim[3] = static_cast(1); + Q_arg.elementStrides[0] = static_cast(1); + Q_arg.elementStrides[1] = static_cast(1); + Q_arg.elementStrides[2] = static_cast(1); + Q_arg.elementStrides[3] = static_cast(1); + Q_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + Q_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + Q_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + Q_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + K_arg.map = &K_desc; + K_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + K_arg.tensorRank = 4; + K_arg.globalAddress = params.k_ptr; + K_arg.globalDim[0] = static_cast(64); + K_arg.globalDim[1] = static_cast(1); + K_arg.globalDim[2] = static_cast(256); + K_arg.globalDim[3] = static_cast(1); + K_arg.globalStride[0] = static_cast(2); + K_arg.globalStride[1] = static_cast(128); + K_arg.globalStride[2] = static_cast(128); + K_arg.globalStride[3] = static_cast(32768); + K_arg.boxDim[0] = static_cast(64); + K_arg.boxDim[1] = static_cast(1); + K_arg.boxDim[2] = static_cast(64); + K_arg.boxDim[3] = static_cast(1); + K_arg.elementStrides[0] = static_cast(1); + K_arg.elementStrides[1] = static_cast(1); + K_arg.elementStrides[2] = static_cast(1); + K_arg.elementStrides[3] = static_cast(1); + K_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + K_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + K_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + K_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + V_arg.map = &V_desc; + V_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + V_arg.tensorRank = 4; + V_arg.globalAddress = params.v_ptr; + V_arg.globalDim[0] = static_cast(64); + V_arg.globalDim[1] = static_cast(1); + V_arg.globalDim[2] = static_cast(256); + V_arg.globalDim[3] = static_cast(1); + V_arg.globalStride[0] = static_cast(2); + V_arg.globalStride[1] = static_cast(128); + V_arg.globalStride[2] = static_cast(128); + V_arg.globalStride[3] = static_cast(32768); + V_arg.boxDim[0] = static_cast(64); + V_arg.boxDim[1] = static_cast(1); + V_arg.boxDim[2] = static_cast(64); + V_arg.boxDim[3] = static_cast(1); + V_arg.elementStrides[0] = static_cast(1); + V_arg.elementStrides[1] = static_cast(1); + V_arg.elementStrides[2] = static_cast(1); + V_arg.elementStrides[3] = static_cast(1); + V_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + V_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + V_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + V_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + CUresult result; + result = cuTensorMapEncodeTiled( + Q_arg.map, Q_arg.type, Q_arg.tensorRank, Q_arg.globalAddress, Q_arg.globalDim, Q_arg.globalStride + 1, Q_arg.boxDim, + Q_arg.elementStrides, Q_arg.interleave, Q_arg.swizzle, Q_arg.l2Promotion, Q_arg.oobFill); + if (result != CUDA_SUCCESS) { + std::cout << "Failed to initialize the TMA descriptor " << result << std::endl + << Q_arg.ToDebugString(); + } + + result = cuTensorMapEncodeTiled( + K_arg.map, K_arg.type, K_arg.tensorRank, K_arg.globalAddress, K_arg.globalDim, K_arg.globalStride + 1, K_arg.boxDim, + K_arg.elementStrides, K_arg.interleave, K_arg.swizzle, K_arg.l2Promotion, K_arg.oobFill); + if (result != CUDA_SUCCESS) { + std::cout << "Failed to initialize the TMA descriptor " << result << std::endl + << K_arg.ToDebugString(); + } + + result = 
cuTensorMapEncodeTiled( + V_arg.map, V_arg.type, V_arg.tensorRank, V_arg.globalAddress, V_arg.globalDim, V_arg.globalStride + 1, V_arg.boxDim, + V_arg.elementStrides, V_arg.interleave, V_arg.swizzle, V_arg.l2Promotion, V_arg.oobFill); + if (result != CUDA_SUCCESS) { + std::cout << "Failed to initialize the TMA descriptor " << result << std::endl + << V_arg.ToDebugString(); + } + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; + return; + } + + main_kernel<<>>(K_desc, (half_t*)params.output_ptr, Q_desc, V_desc); + + err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(err) << std::endl; + return; + } + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; + return; + } +} +// template +// static std::string ArrayToStr(const T* ptr, size_t n) { +// std::stringstream ss; +// ss << "["; +// for (size_t i = 0; i < n; i++) { +// if (i > 0) ss << ", "; +// ss << ptr[i]; +// } +// ss << "]"; +// return ss.str(); +// } + +// struct TensorMapArgs { +// CUtensorMap* map; +// CUtensorMapDataType type; +// cuuint32_t tensorRank; +// void* globalAddress; +// cuuint64_t globalDim[5], globalStride[5]; +// cuuint32_t boxDim[5], elementStrides[5]; +// CUtensorMapInterleave interleave; +// CUtensorMapSwizzle swizzle; +// CUtensorMapL2promotion l2Promotion; +// CUtensorMapFloatOOBfill oobFill; + +// std::string ToDebugString() { +// std::stringstream ss; +// ss << "TMA Desc Addr: " << map << std::endl +// << "format " << type << std::endl +// << "dim " << tensorRank << std::endl +// << "gmem_address " << globalAddress << std::endl +// << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl +// << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl +// << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl +// << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl +// << "interleave " << interleave << std::endl +// << "swizzle " << swizzle << std::endl +// << "l2Promotion " << l2Promotion << std::endl +// << "oobFill " << oobFill << std::endl; +// return ss.str(); +// } +// }; + +// __global__ void fillWithOnes(void *ptr, size_t size) { +// half_t *data = (half_t *)ptr; +// size_t index = threadIdx.x + blockIdx.x * blockDim.x; +// if (index < size) { +// data[index] = 1; +// } +// } + +// int main() { +// dim3 grid(4); +// dim3 block(129); +// size_t sharedMemSize = 24576; + +// int batch = 1; +// int head = 1; +// int seq_len = 256; +// int dim = 64; +// int size = batch * head * seq_len * dim * sizeof(half_t); + +// void *Q, *K, *V, *d_output; +// void *h_output; +// h_output = (void*)malloc(size); +// cudaMalloc((void**)&Q, size); +// cudaMalloc((void**)&K, size); +// cudaMalloc((void**)&V, size); +// cudaMalloc((void**)&d_output, size); + + +// int threadsPerBlock = 256; +// int blocksPerGrid = (batch * head * seq_len * dim + threadsPerBlock - 1) / threadsPerBlock; +// fillWithOnes<<>>(Q, batch * head * seq_len * dim + threadsPerBlock); +// fillWithOnes<<>>(K, batch * head * seq_len * dim + threadsPerBlock); +// fillWithOnes<<>>(V, batch * head * seq_len * dim + threadsPerBlock); + +// cudaError_t err = cudaDeviceSynchronize(); +// if (err != cudaSuccess) { +// std::cerr << "fillWithOnes failed: " << cudaGetErrorString(err) << std::endl; +// return 1; +// } + +// 
CUtensorMap Q_desc = {0}; +// CUtensorMap K_desc = {0}; +// CUtensorMap V_desc = {0}; +// TensorMapArgs Q_arg; +// TensorMapArgs K_arg; +// TensorMapArgs V_arg; + +// Q_arg.map = &Q_desc; +// Q_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; +// Q_arg.tensorRank = 4; +// Q_arg.globalAddress = Q; +// Q_arg.globalDim[0] = static_cast(64); +// Q_arg.globalDim[1] = static_cast(1); +// Q_arg.globalDim[2] = static_cast(256); +// Q_arg.globalDim[3] = static_cast(1); +// Q_arg.globalStride[0] = static_cast(2); +// Q_arg.globalStride[1] = static_cast(128); +// Q_arg.globalStride[2] = static_cast(128); +// Q_arg.globalStride[3] = static_cast(32768); +// Q_arg.boxDim[0] = static_cast(64); +// Q_arg.boxDim[1] = static_cast(1); +// Q_arg.boxDim[2] = static_cast(64); +// Q_arg.boxDim[3] = static_cast(1); +// Q_arg.elementStrides[0] = static_cast(1); +// Q_arg.elementStrides[1] = static_cast(1); +// Q_arg.elementStrides[2] = static_cast(1); +// Q_arg.elementStrides[3] = static_cast(1); +// Q_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; +// Q_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; +// Q_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; +// Q_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + +// K_arg.map = &K_desc; +// K_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; +// K_arg.tensorRank = 4; +// K_arg.globalAddress = K; +// K_arg.globalDim[0] = static_cast(64); +// K_arg.globalDim[1] = static_cast(1); +// K_arg.globalDim[2] = static_cast(256); +// K_arg.globalDim[3] = static_cast(1); +// K_arg.globalStride[0] = static_cast(2); +// K_arg.globalStride[1] = static_cast(128); +// K_arg.globalStride[2] = static_cast(128); +// K_arg.globalStride[3] = static_cast(32768); +// K_arg.boxDim[0] = static_cast(64); +// K_arg.boxDim[1] = static_cast(1); +// K_arg.boxDim[2] = static_cast(64); +// K_arg.boxDim[3] = static_cast(1); +// K_arg.elementStrides[0] = static_cast(1); +// K_arg.elementStrides[1] = static_cast(1); +// K_arg.elementStrides[2] = static_cast(1); +// K_arg.elementStrides[3] = static_cast(1); +// K_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; +// K_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; +// K_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; +// K_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + +// V_arg.map = &V_desc; +// V_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; +// V_arg.tensorRank = 4; +// V_arg.globalAddress = V; +// V_arg.globalDim[0] = static_cast(64); +// V_arg.globalDim[1] = static_cast(1); +// V_arg.globalDim[2] = static_cast(256); +// V_arg.globalDim[3] = static_cast(1); +// V_arg.globalStride[0] = static_cast(2); +// V_arg.globalStride[1] = static_cast(128); +// V_arg.globalStride[2] = static_cast(128); +// V_arg.globalStride[3] = static_cast(32768); +// V_arg.boxDim[0] = static_cast(64); +// V_arg.boxDim[1] = static_cast(1); +// V_arg.boxDim[2] = static_cast(64); +// V_arg.boxDim[3] = static_cast(1); +// V_arg.elementStrides[0] = static_cast(1); +// V_arg.elementStrides[1] = static_cast(1); +// V_arg.elementStrides[2] = static_cast(1); +// V_arg.elementStrides[3] = static_cast(1); +// V_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; +// V_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; +// V_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; +// V_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + +// CUresult result; +// result = cuTensorMapEncodeTiled( +// Q_arg.map, Q_arg.type, Q_arg.tensorRank, Q_arg.globalAddress, Q_arg.globalDim, Q_arg.globalStride + 1, Q_arg.boxDim, +// Q_arg.elementStrides, Q_arg.interleave, Q_arg.swizzle, Q_arg.l2Promotion, Q_arg.oobFill); +// 
if (result != CUDA_SUCCESS) { +// std::cout << "Failed to initialize the TMA descriptor " << result << std::endl +// << Q_arg.ToDebugString(); +// } + +// result = cuTensorMapEncodeTiled( +// K_arg.map, K_arg.type, K_arg.tensorRank, K_arg.globalAddress, K_arg.globalDim, K_arg.globalStride + 1, K_arg.boxDim, +// K_arg.elementStrides, K_arg.interleave, K_arg.swizzle, K_arg.l2Promotion, K_arg.oobFill); +// if (result != CUDA_SUCCESS) { +// std::cout << "Failed to initialize the TMA descriptor " << result << std::endl +// << K_arg.ToDebugString(); +// } + +// result = cuTensorMapEncodeTiled( +// V_arg.map, V_arg.type, V_arg.tensorRank, V_arg.globalAddress, V_arg.globalDim, V_arg.globalStride + 1, V_arg.boxDim, +// V_arg.elementStrides, V_arg.interleave, V_arg.swizzle, V_arg.l2Promotion, V_arg.oobFill); +// if (result != CUDA_SUCCESS) { +// std::cout << "Failed to initialize the TMA descriptor " << result << std::endl +// << V_arg.ToDebugString(); +// } + +// if (err != cudaSuccess) { +// std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; +// return 1; +// } + +// main_kernel<<>>(K_desc, (half_t*)d_output, Q_desc, V_desc); + +// err = cudaGetLastError(); +// if (err != cudaSuccess) { +// std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(err) << std::endl; +// return 1; +// } + +// err = cudaDeviceSynchronize(); +// if (err != cudaSuccess) { +// std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; +// return 1; +// } + +// cudaMemcpy((void*)h_output, (void*)d_output, size, cudaMemcpyDeviceToHost); + +// std::cout << "CUDA kernel executed successfully." << std::endl; +// for (int i = 0; i < seq_len; i++) { +// for (int j = 0; j < dim; j++) { +// std::cout << ((half_t*)h_output)[i * dim + j] << " "; +// } +// std::cout << std::endl; +// } +// std::cout << std::endl; +// return 0; +// } \ No newline at end of file diff --git a/tl_verify/fa_kernel.hpp b/tl_verify/fa_kernel.hpp new file mode 100644 index 000000000000..0cd6f228cc3c --- /dev/null +++ b/tl_verify/fa_kernel.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +struct Flash_fwd_params +{ + using index_t = int64_t; + // The QKV matrices. 
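// (All four are device pointers to half-precision tensors with logical layout
//  [batch, seq_len, heads, dim]; batch/seq_len/head/dim describe the attention problem size,
//  while block_M and block_N appear to be the kernel tile sizes along the query and
//  key/value sequence dimensions.)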
+ void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + void *__restrict__ output_ptr; + + index_t batch; + index_t seq_len; + index_t head; + index_t dim; + index_t block_M; + index_t block_N; +}; + +void host_function(Flash_fwd_params params); +void host_function_no_tma(Flash_fwd_params params); + diff --git a/tl_verify/fa_no_tma.cu b/tl_verify/fa_no_tma.cu new file mode 100644 index 000000000000..65ae4030f96f --- /dev/null +++ b/tl_verify/fa_no_tma.cu @@ -0,0 +1,199 @@ +#include +#include +#include +#include +#include +#include +#include "fa_kernel.hpp" + +template +__device__ void print_(T* reg, const char* name, int range, int total_threads) { + __syncthreads(); + if ((int)blockIdx.x == 0) { + if ((int)threadIdx.x == 0) { + printf("\n%s:\n", name); + } + for (int tid = 0; tid < total_threads; tid++) { + __syncthreads(); + if (threadIdx.x == tid) { + printf("tid: %d: ", tid); + for (int i = 0; i < range; i++) { + printf("%f ", float(reg[i])); + } + printf("\n"); + } + } + } + __syncthreads(); +} + + +extern "C" __global__ void __launch_bounds__(128) main_kernel_no_tma(__grid_constant__ const CUtensorMap K_desc, half_t* __restrict__ Output, __grid_constant__ const CUtensorMap Q_desc, __grid_constant__ const CUtensorMap V_desc) { +} + + +struct TensorMapArgs { + CUtensorMap* map; + CUtensorMapDataType type; + cuuint32_t tensorRank; + void* globalAddress; + cuuint64_t globalDim[5], globalStride[5]; + cuuint32_t boxDim[5], elementStrides[5]; + CUtensorMapInterleave interleave; + CUtensorMapSwizzle swizzle; + CUtensorMapL2promotion l2Promotion; + CUtensorMapFloatOOBfill oobFill; + + std::string ToDebugString() { + std::stringstream ss; + // ss << "TMA Desc Addr: " << map << std::endl + // << "format " << type << std::endl + // << "dim " << tensorRank << std::endl + // << "gmem_address " << globalAddress << std::endl + // << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl + // << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl + // << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl + // << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl + // << "interleave " << interleave << std::endl + // << "swizzle " << swizzle << std::endl + // << "l2Promotion " << l2Promotion << std::endl + // << "oobFill " << oobFill << std::endl; + return ss.str(); + } +}; + +void host_function_no_tma(Flash_fwd_params params) { + int num_m_blocks = (params.seq_len + params.block_M - 1) / params.block_M; + dim3 grid(num_m_blocks, params.head, params.batch); + dim3 block(128); + size_t sharedMemSize = (params.block_M + 2 * params.block_N) * params.dim * sizeof(half_t); // 24576; + + // int size = params.batch * params.head * params.seq_len * params.dim * sizeof(half_t); + + CUtensorMap Q_desc = {0}; + CUtensorMap K_desc = {0}; + CUtensorMap V_desc = {0}; + TensorMapArgs Q_arg; + TensorMapArgs K_arg; + TensorMapArgs V_arg; + + Q_arg.map = &Q_desc; + Q_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + Q_arg.tensorRank = 4; + Q_arg.globalAddress = params.q_ptr; + Q_arg.globalDim[0] = static_cast(params.dim); + Q_arg.globalDim[1] = static_cast(params.head); + Q_arg.globalDim[2] = static_cast(params.seq_len); + Q_arg.globalDim[3] = static_cast(params.batch); + Q_arg.globalStride[0] = static_cast(2); + Q_arg.globalStride[1] = static_cast(128); + Q_arg.globalStride[2] = static_cast(128); + Q_arg.globalStride[3] = static_cast(32768); + Q_arg.boxDim[0] = static_cast(64); + Q_arg.boxDim[1] = static_cast(1); + Q_arg.boxDim[2] = 
static_cast(64); + Q_arg.boxDim[3] = static_cast(1); + Q_arg.elementStrides[0] = static_cast(1); + Q_arg.elementStrides[1] = static_cast(1); + Q_arg.elementStrides[2] = static_cast(1); + Q_arg.elementStrides[3] = static_cast(1); + Q_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + Q_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + Q_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + Q_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + K_arg.map = &K_desc; + K_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + K_arg.tensorRank = 4; + K_arg.globalAddress = params.k_ptr; + K_arg.globalDim[0] = static_cast(64); + K_arg.globalDim[1] = static_cast(1); + K_arg.globalDim[2] = static_cast(256); + K_arg.globalDim[3] = static_cast(1); + K_arg.globalStride[0] = static_cast(2); + K_arg.globalStride[1] = static_cast(128); + K_arg.globalStride[2] = static_cast(128); + K_arg.globalStride[3] = static_cast(32768); + K_arg.boxDim[0] = static_cast(64); + K_arg.boxDim[1] = static_cast(1); + K_arg.boxDim[2] = static_cast(64); + K_arg.boxDim[3] = static_cast(1); + K_arg.elementStrides[0] = static_cast(1); + K_arg.elementStrides[1] = static_cast(1); + K_arg.elementStrides[2] = static_cast(1); + K_arg.elementStrides[3] = static_cast(1); + K_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + K_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + K_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + K_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + V_arg.map = &V_desc; + V_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + V_arg.tensorRank = 4; + V_arg.globalAddress = params.v_ptr; + V_arg.globalDim[0] = static_cast(64); + V_arg.globalDim[1] = static_cast(1); + V_arg.globalDim[2] = static_cast(256); + V_arg.globalDim[3] = static_cast(1); + V_arg.globalStride[0] = static_cast(2); + V_arg.globalStride[1] = static_cast(128); + V_arg.globalStride[2] = static_cast(128); + V_arg.globalStride[3] = static_cast(32768); + V_arg.boxDim[0] = static_cast(64); + V_arg.boxDim[1] = static_cast(1); + V_arg.boxDim[2] = static_cast(64); + V_arg.boxDim[3] = static_cast(1); + V_arg.elementStrides[0] = static_cast(1); + V_arg.elementStrides[1] = static_cast(1); + V_arg.elementStrides[2] = static_cast(1); + V_arg.elementStrides[3] = static_cast(1); + V_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + V_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + V_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + V_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + CUresult result; + result = cuTensorMapEncodeTiled( + Q_arg.map, Q_arg.type, Q_arg.tensorRank, Q_arg.globalAddress, Q_arg.globalDim, Q_arg.globalStride + 1, Q_arg.boxDim, + Q_arg.elementStrides, Q_arg.interleave, Q_arg.swizzle, Q_arg.l2Promotion, Q_arg.oobFill); + if (result != CUDA_SUCCESS) { + std::cout << "Failed to initialize the TMA descriptor " << result << std::endl + << Q_arg.ToDebugString(); + } + + result = cuTensorMapEncodeTiled( + K_arg.map, K_arg.type, K_arg.tensorRank, K_arg.globalAddress, K_arg.globalDim, K_arg.globalStride + 1, K_arg.boxDim, + K_arg.elementStrides, K_arg.interleave, K_arg.swizzle, K_arg.l2Promotion, K_arg.oobFill); + if (result != CUDA_SUCCESS) { + std::cout << "Failed to initialize the TMA descriptor " << result << std::endl + << K_arg.ToDebugString(); + } + + result = cuTensorMapEncodeTiled( + V_arg.map, V_arg.type, V_arg.tensorRank, V_arg.globalAddress, V_arg.globalDim, V_arg.globalStride + 1, V_arg.boxDim, + V_arg.elementStrides, V_arg.interleave, V_arg.swizzle, V_arg.l2Promotion, V_arg.oobFill); + if (result != CUDA_SUCCESS) { + std::cout << "Failed to 
initialize the TMA descriptor " << result << std::endl + << V_arg.ToDebugString(); + } + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; + return; + } + + main_kernel_no_tma<<>>(K_desc, (half_t*)params.output_ptr, Q_desc, V_desc); + + err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(err) << std::endl; + return; + } + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; + return; + } +} \ No newline at end of file diff --git a/tl_verify/main.py b/tl_verify/main.py new file mode 100644 index 000000000000..0d00d5307af2 --- /dev/null +++ b/tl_verify/main.py @@ -0,0 +1,69 @@ +import torch +import fa_test +# from flash_attn.flash_attn_interface import flash_attn_func +import random +import numpy as np + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + np.random.seed(seed) # Numpy module. + random.seed(seed) # Python random module. + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def ref_program(Q, K, V, casual): + # from flash_attn.flash_attn_interface import flash_attn_func + + # return flash_attn_func(Q, K, V, causal=casual) + assert casual == False, "casual is not supported" + batch, seq_len, heads, dim = Q.size() + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + block_M = seq_len + block_N = 64 if dim <= 128 else 32 + acc_s = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, block_M, heads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float('-inf')) + Q_scaled = Q * scale + + for i in range(int(seq_len / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_scaled, K[:, i * block_N : (i + 1) * block_N, :, :]) # [batch, seqlen, heads, block_N] + # scores_max_prev = scores_max + # scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] + # scores_scale = torch.exp2(scores_max_prev - scores_max) + # acc_o *= scores_scale[:, :, :, None].transpose(1, 2) + acc_s = torch.exp2(acc_s - 32) + acc_s_cast = acc_s.to(torch.float16) + acc_o += torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, i * block_N : (i + 1) * block_N, :, :]) + # scores_sum = acc_s.sum(dim=-1, keepdim=False) + # logsum = logsum * scores_scale + scores_sum + # acc_o /= logsum[:, :, :, None].transpose(1, 2) + return acc_o.to(torch.float16) + +set_seed(42) +batch, seq_len, heads, dim = 1, 256, 1, 64 +shape = [batch, seq_len, heads, dim] +# q = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) +# k = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) +q = torch.ones(*shape, device='cuda', dtype=torch.float16) +k 
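# Note on ref_program above: it is a deliberately simplified reference. The online-softmax
# bookkeeping (running max, rescaling, logsum normalization) is commented out and exp2 is
# taken against a fixed constant (32) rather than the per-row max, so it produces an
# unnormalized attention output and only matches a kernel that makes the same simplification.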
= torch.ones(*shape, device='cuda', dtype=torch.float16) +v = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) + +# output = fa_test.kernel_function(q, k, v) +output = fa_test.kernel_function_no_tma(q, k, v) +# ref_output = flash_attn_func(q, k, v, causal=False) +ref_output = ref_program(q, k, v, False) +print(output) +print(ref_output) +# print(ref_output) +# assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2) \ No newline at end of file diff --git a/tl_verify/setup.py b/tl_verify/setup.py new file mode 100644 index 000000000000..d38ed6637c21 --- /dev/null +++ b/tl_verify/setup.py @@ -0,0 +1,49 @@ +from setuptools import setup +import torch.utils.cpp_extension +from torch.utils.cpp_extension import CUDAExtension, BuildExtension +torch.utils.cpp_extension.CUDAExtension.debug = True + +extra_compile_args = { + 'cxx': ['-O3', '-std=c++17'], + 'nvcc': [ + '-arch=sm_90a', + '--use_fast_math', + '-std=c++17', + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + '-I/usr/local/cuda/include', + '-I/home/msra/cy/tvm.tl/src/tl', + '-I/home/msra/cy/tvm.tl/cutlass/include', + '-lcuda', + # '-keep' # Uncomment this line to keep the generated .ptx file + ], +} + +include_dirs = [ + '/home/msra/cy/tvm.tl/src/tl', + '/home/msra/cy/tvm.tl/cutlass/include', + '/usr/local/cuda/include' +] + +setup( + name='fa_test', + ext_modules=[ + CUDAExtension( + 'fa_test', + sources=['cuda_interface.cpp', 'fa_kernel.cu', 'fa_no_tma.cu'], + extra_compile_args=extra_compile_args, + include_dirs=include_dirs, + libraries=["cuda"] + ), + ], + cmdclass={ + 'build_ext': BuildExtension + } +) From 5971c954eee522dd51240a0565e89b3fad249f43 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Sun, 1 Sep 2024 17:16:54 +0000 Subject: [PATCH 12/23] [tl] Fix mha bugs --- tl_scripts/mha_example.py | 9 ++-- tl_scripts/mha_pipeline.py | 106 +++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 5 deletions(-) create mode 100644 tl_scripts/mha_pipeline.py diff --git a/tl_scripts/mha_example.py b/tl_scripts/mha_example.py index 602d162dad05..486b90ce6f05 100644 --- a/tl_scripts/mha_example.py +++ b/tl_scripts/mha_example.py @@ -43,7 +43,6 @@ def main( ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) - Q_local = T.alloc_fragment([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -60,9 +59,6 @@ def main( T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - T.copy(Q_shared, Q_local) - for i, j in T.Parallel(block_M, dim): - Q_local[i, j] *= scale loop_range = ( T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) ) @@ -75,9 +71,12 @@ def main( ) else: T.clear(acc_s) - T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + for i, j in T.Parallel(block_M, dim): + acc_s[i, j] *= scale T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, 
scores_max, dim=1, clear=False) for i in T.Parallel(block_M): scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) diff --git a/tl_scripts/mha_pipeline.py b/tl_scripts/mha_pipeline.py new file mode 100644 index 000000000000..e447a9002ada --- /dev/null +++ b/tl_scripts/mha_pipeline.py @@ -0,0 +1,106 @@ +import torch +from tvm import tl +import tvm.tl.language as T +from functools import partial + + +def flashattn(batch, heads, seq_len, dim, is_casual, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + Output: T.Buffer(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + # Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + # T.copy(Q_shared, Q_local) + # for i, j in T.Parallel(block_M, dim): + # Q_local[i, j] *= scale + loop_range = ( + T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) + ) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_casual: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) + ) + else: + T.clear(acc_s) + # T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + for i, j in T.Parallel(block_M, dim): + acc_s[i, j] *= scale + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, casual): + from flash_attn.flash_attn_interface import flash_attn_func + + return flash_attn_func(Q, K, V, causal=casual) + + +if __name__ == "__main__": + BATCH, H, 
N_CTX, D_HEAD = 64, 12, 2048, 64 + casual = True + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if casual: + total_flops *= 0.5 + BLOCK_M = 64 + BLOCK_N = 64 if D_HEAD <= 128 else 32 + program = flashattn(BATCH, H, N_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) + ref_program = partial(ref_program, casual=casual) + mod, params = tl.lower(program) + mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) + + latency = mod.do_bench(ref_program, warmup=500) + print("{:.2f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = mod.do_bench(mod) + print("{:.2f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file From 191c715512da89f3bcbc93cb5ca893b04fd08ca9 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Mon, 2 Sep 2024 08:25:06 +0000 Subject: [PATCH 13/23] [tl] Update mha_pipeline, correctness done. --- tl_scripts/mha_pipeline.py | 132 +++++++++++++++++++++++++++---------- 1 file changed, 99 insertions(+), 33 deletions(-) diff --git a/tl_scripts/mha_pipeline.py b/tl_scripts/mha_pipeline.py index e447a9002ada..d515427be6f6 100644 --- a/tl_scripts/mha_pipeline.py +++ b/tl_scripts/mha_pipeline.py @@ -3,6 +3,37 @@ import tvm.tl.language as T from functools import partial +# Codegen bug: +# LoadK should wait for MMA0 done +# @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) +# def tvm_callback_cuda_postproc(code, _): +# code = code.replace("""tl::mbarrier_wait(_mbarrier[1], ((k & 1) ^ 1));""", +# """tl::mbarrier_wait(_mbarrier[1], ((k & 1))); // replace""") +# code = code.replace("""tl::gemm_ss<64, 64, 64, 4, 1, 0, 1>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[4096])), (&(acc_s[0]))); +# #pragma unroll""", +# """tl::gemm_ss<64, 64, 64, 4, 1, 0, 1>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[4096])), (&(acc_s[0]))); +# tl::mbarrier_arrive(_mbarrier[1]); +# #pragma unroll // replace""") +# return code + +# loadk(0) +# gemm0(0) +# loadk(1) +# softmax(0) +# loadv(0) + +# for i in range(loop_range - 2): +# gemm0(i+1) +# gemm1(i+0) +# loadk(i+2) +# softmax(i+1) +# loadv(i+1) + +# gemm0(loop_range - 1) +# gemm1(loop_range - 2) +# softmax(loop_range - 1) +# loadv(loop_range - 1) +# gemm1(loop_range - 1) def flashattn(batch, heads, seq_len, dim, is_casual, block_M, block_N): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) @@ -10,6 +41,60 @@ def flashattn(batch, heads, seq_len, dim, is_casual, block_M, block_N): dtype = "float16" accum_dtype = "float" + @T.macro + def MMA0( + K: T.Buffer(shape, dtype), + Q_shared: T.Buffer([block_M, dim], dtype), + K_shared: T.Buffer([block_N, dim], dtype), + acc_s: T.Buffer([block_M, block_N], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Buffer(shape, dtype), + V_shared: T.Buffer([block_M, dim], dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Buffer([block_M, block_N], accum_dtype), + acc_s_cast: 
T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_max: T.Buffer([block_M], accum_dtype), + scores_max_prev: T.Buffer([block_M], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + scores_sum: T.Buffer([block_M], accum_dtype), + logsum: T.Buffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_s[i, j] *= scale + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + T.copy(acc_s, acc_s_cast) + @T.prim_func def main( Q: T.Buffer(shape, dtype), @@ -19,7 +104,6 @@ def main( ): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) - # Q_local = T.alloc_fragment([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -36,40 +120,22 @@ def main( T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - # T.copy(Q_shared, Q_local) - # for i, j in T.Parallel(block_M, dim): - # Q_local[i, j] *= scale + + MMA0(K, Q_shared, K_shared, acc_s, 0, by, bz) + Softmax(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + loop_range = ( T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) ) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) - if is_casual: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else( - bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) - ) - else: - T.clear(acc_s) - # T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) - for i, j in T.Parallel(block_M, dim): - acc_s[i, j] *= scale - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) - T.copy(acc_s, acc_s_cast) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + if k < loop_range - 1: + MMA0(K, Q_shared, K_shared, acc_s, k + 1, by, bz) + + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + + if k < loop_range - 1: + Softmax(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) @@ -84,8 +150,8 @@ def ref_program(Q, K, V, casual): if __name__ == 
"__main__": - BATCH, H, N_CTX, D_HEAD = 64, 12, 2048, 64 - casual = True + BATCH, H, N_CTX, D_HEAD = 64, 16, 4096, 64 + casual = False flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul if casual: From 1fd50fa122277724817a68d3f887b9b62efe9cdd Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Sun, 8 Sep 2024 12:19:06 +0000 Subject: [PATCH 14/23] [tl] Add pipeline for consumers. Correct on mha_pipeline.py --- .gitignore | 5 +- python/tvm/tl/engine.py | 22 +- python/tvm/tl/language.py | 4 +- python/tvm/tl/transform.py | 23 + src/tir/transforms/lower_opaque_block.cc | 8 + src/tl/ir.cc | 5 +- src/tl/transform/inject_mbarrier.cc | 242 +++ .../multi_version_buffer_rewriter.cc | 321 ++++ src/tl/transform/warp_specialized_pipeline.cc | 1698 ++++++++--------- src/tl/transform/warp_specialized_rewriter.cc | 738 +++++++ tl_scripts/mha_pipeline.py | 25 +- 11 files changed, 2220 insertions(+), 871 deletions(-) create mode 100644 src/tl/transform/inject_mbarrier.cc create mode 100644 src/tl/transform/multi_version_buffer_rewriter.cc create mode 100644 src/tl/transform/warp_specialized_rewriter.cc diff --git a/.gitignore b/.gitignore index 921890a62759..55718420119d 100644 --- a/.gitignore +++ b/.gitignore @@ -279,5 +279,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py .gdb_history cmake/config.cmake -reports/* -play.py \ No newline at end of file +*/reports/* +play.py +*.ptx \ No newline at end of file diff --git a/python/tvm/tl/engine.py b/python/tvm/tl/engine.py index 70ac35db8124..3e2eaacef4de 100644 --- a/python/tvm/tl/engine.py +++ b/python/tvm/tl/engine.py @@ -113,7 +113,27 @@ def lower(func): # print(mod) if target.arch == "sm_90": - mod = tl.transform.WarpSpecializedPipeline()(mod) + mod = tl.transform.MultiVersionBuffer()(mod) + + # print('-'*100 + '\n' + 'after MultiVersionBuffer\n' + '-'*100) + # print(mod) + + mod = tl.transform.WarpSpecialized()(mod) + + # print('-'*100 + '\n' + 'after WarpSpecialized\n' + '-'*100) + # print(mod) + + mod = tl.transform.InjectSoftwarePipeline()(mod) + + # print('-'*100 + '\n' + 'after InjectSoftwarePipeline\n' + '-'*100) + # print(mod) + + mod = tir.transform.LowerOpaqueBlock()(mod) + + # print('-'*100 + '\n' + 'after LowerOpaqueBlock\n' + '-'*100) + # print(mod) + + # mod = tl.transform.WarpSpecializedPipeline()(mod) # print('-'*100 + '\n' + 'after WarpSpecializedPipeline\n' + '-'*100) # print(mod) diff --git a/python/tvm/tl/language.py b/python/tvm/tl/language.py index d049c817b292..7b0a098ee849 100644 --- a/python/tvm/tl/language.py +++ b/python/tvm/tl/language.py @@ -43,7 +43,7 @@ def Parallel(*extents: tir.PrimExpr): return _ffi_api.Parallel(extents) # type: ignore[attr-defined] # pylint: disable=no-member -def Pipelined(start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int = 0): +def Pipelined(start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int = 0, order: List[int] = [], stage: List[int] = [], group: List[List[int]] = []): """Tools to construct pipelined for loop. 
Parameters @@ -67,7 +67,7 @@ def Pipelined(start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int = else: start = 0 # type: ignore[attr-defined] # pylint: disable=no-member - return _ffi_api.Pipelined(start, stop, num_stages) + return _ffi_api.Pipelined(start, stop, num_stages, order, stage, group) @register_object("tl.KernelLaunchFrame") diff --git a/python/tvm/tl/transform.py b/python/tvm/tl/transform.py index f040532cbdac..9727799ee2da 100644 --- a/python/tvm/tl/transform.py +++ b/python/tvm/tl/transform.py @@ -107,6 +107,29 @@ def WarpSpecializedPipeline(): """ return _ffi_api.WarpSpecializedPipeline() # type: ignore + +def MultiVersionBuffer(): + """WarpSpecializedPipeline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MultiVersionBuffer() # type: ignore + + +def WarpSpecialized(): + """WarpSpecializedPipeline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.WarpSpecialized() # type: ignore + + def InjectFenceProxy(): """InjectFenceProxy diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 86892433b42d..a381cf2f9fd5 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -82,6 +82,14 @@ class OpaqueBlockLower : public StmtExprMutator { return body; } + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + if (block->annotations.count("stmt_group")) { + return block->body; + } + return block; + } + Stmt VisitStmt_(const ForNode* op) final { // Step 1. Update unit loop info. PrimExpr min = this->VisitExpr(op->min); diff --git a/src/tl/ir.cc b/src/tl/ir.cc index 98462bde6dbe..5d440773e4a1 100644 --- a/src/tl/ir.cc +++ b/src/tl/ir.cc @@ -54,7 +54,7 @@ ForFrame ParallelFor(Array extents) { return ForFrame(n); } -ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages) { +ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages, Array order, Array stages, Array> groups) { using namespace tvm::tir; ObjectPtr n = make_object(); DataType dtype = stop.dtype(); @@ -66,6 +66,9 @@ ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages) { ICHECK(n == 1); Map anno; if (num_stages > 0) anno.Set("num_stages", PrimExpr(num_stages)); + anno.Set("software_pipeline_order", order); + anno.Set("software_pipeline_stage", stages); + anno.Set("software_pipeline_group", groups); body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, std::move(body), /*thread_binding=*/NullOpt, /*annotations=*/anno); return body; diff --git a/src/tl/transform/inject_mbarrier.cc b/src/tl/transform/inject_mbarrier.cc new file mode 100644 index 000000000000..c053c631927a --- /dev/null +++ b/src/tl/transform/inject_mbarrier.cc @@ -0,0 +1,242 @@ +// /* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. 
You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ + +// /*! +// * \file warp_specialized_pipeline.cc +// * \brief Warp specialized Pipeline for cuda GPU (sm90+) +// */ + +// #include +// #include +// #include +// #include +// #include + +// #include "../op/builtin.h" + +// namespace tvm { +// namespace tl { + +// using namespace tir; + +// enum class Role { kConsumer, kProducer, kBoth }; + +// class WarpSpecializedRoleMarker : public StmtVisitor { +// public: +// WarpSpecializedRoleMarker(Map buffer_data_to_buffer) +// : buffer_data_to_buffer_(buffer_data_to_buffer) {} + +// Role GetRole(const StmtNode* stmt) const { +// auto it = map_.find(stmt); +// ICHECK(it != map_.end()); +// return it->second; +// } + +// Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); } + +// void VisitStmt_(const EvaluateNode* op) final { +// Role role = Role::kConsumer; +// if (auto call = op->value.as()) { +// if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { +// role = Role::kProducer; +// has_bulk_copy_ = true; +// } +// } +// SetRole(op, role); +// } + +// void VisitStmt_(const BufferStoreNode* op) final { +// bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; +// if (!is_shared_store) { +// SetRole(op, Role::kConsumer); +// return; +// } + +// // Check reads from global +// Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", +// /*body*/ GetRef(op)); +// auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); +// auto reads = access[0]; +// Role role = Role::kProducer; +// for (auto read : reads) { +// if (read->buffer.scope() != "global") { +// role = Role::kConsumer; +// break; +// } +// } +// if (role == Role::kProducer) has_simt_copy_ = true; +// SetRole(op, role); +// } + +// void VisitStmt_(const SeqStmtNode* op) final { +// StmtVisitor::VisitStmt_(op); +// auto role = GetRole(op->seq[0]); +// for (auto stmt : op->seq) { +// if (role != GetRole(stmt)) { +// role = Role::kBoth; +// break; +// } +// } +// SetRole(op, role); +// } + +// void VisitStmt_(const IfThenElseNode* op) final { +// StmtVisitor::VisitStmt_(op); +// auto role = GetRole(op->then_case); +// if (op->else_case.defined()) { +// auto role_else = GetRole(op->else_case.value()); +// if (role != role_else) role = Role::kBoth; +// } +// SetRole(op, role); +// } + +// void VisitStmt_(const BlockRealizeNode* op) final { +// StmtVisitor::VisitStmt_(op); +// SetRole(op, GetRole(op->block)); +// } + +// template +// void HandleBodyStmt(const NodeType* op) { +// StmtVisitor::VisitStmt_(op); +// SetRole(op, GetRole(op->body)); +// } + +// void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); } +// void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); } +// void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); } +// void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); } +// void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); } + +// bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } + +// bool HasSimtCopy() { return has_simt_copy_; } + +// private: +// void 
SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; } +// Map buffer_data_to_buffer_; +// std::unordered_map map_; +// bool has_simt_copy_ = false; +// bool has_bulk_copy_ = false; +// }; + +// static PrimExpr makeGetBarrier(PrimExpr barrier_id) { +// return Call(DataType::Handle(), GetMBarrierOp(), {barrier_id}); +// } + +// static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) { +// auto call = Call(DataType::Handle(), MBarrierExpectTX(), {makeGetBarrier(barrier_id), bytes}); +// return Evaluate(call); +// } + +// static Stmt makeArriveBarrier(PrimExpr barrier_id) { +// auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(), {makeGetBarrier(barrier_id)}); +// return Evaluate(call); +// } + +// static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) { +// auto call = +// Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), {makeGetBarrier(barrier_id)}); +// return Evaluate(call); +// } + +// static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { +// auto call = Call(DataType::Handle(), MBarrierWaitParity(), {makeGetBarrier(barrier_id), parity}); +// return Evaluate(call); +// } + +// class ProducerTraitsCollector : public StmtExprVisitor { +// public: +// ProducerTraitsCollector() { Clear(); } + +// void Clear() { +// bulk_copy_bytes = 0; +// loop_extents = 1; +// has_simt_copy = false; +// } + +// void Collect(Stmt stmt) { VisitStmt(stmt); } + +// bool HasSimtCopy() { return has_simt_copy; } + +// PrimExpr BulkCopyBytes() { return bulk_copy_bytes; } + +// private: +// void VisitExpr_(const CallNode* call) final { +// if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { +// Call access_ptr = Downcast(call->args[2]); +// ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); +// int type_bytes = access_ptr->args[0]->dtype.bytes(); +// bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes; +// } +// StmtExprVisitor::VisitExpr_(call); +// } + +// void VisitStmt_(const ForNode* op) final { +// PrimExpr old_loop_evtents = loop_extents; +// loop_extents *= op->extent; +// StmtExprVisitor::VisitStmt_(op); +// loop_extents = old_loop_evtents; +// } + +// void VisitExpr_(const BufferLoadNode* op) final { +// has_simt_copy = true; +// StmtExprVisitor::VisitExpr_(op); +// } + +// bool has_simt_copy; +// PrimExpr bulk_copy_bytes; +// PrimExpr loop_extents; +// }; + +// // Rewrite the producer Stmt to use the correct barrier index +// class MbarrierRewriter : public StmtExprMutator { +// public: +// static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) { +// MbarrierRewriter rewriter; +// rewriter.producer_barrier_idx_ = barrier_id; +// return rewriter(stmt); +// } + +// private: +// PrimExpr VisitExpr_(const CallNode* op) final { +// auto call = Downcast(StmtExprMutator::VisitExpr_(op)); +// if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { +// Call access_ptr = Downcast(call->args[2]); +// ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); +// call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_)); +// } +// return call; +// } +// PrimExpr producer_barrier_idx_; +// }; + + +// using namespace tir::transform; + +// tvm::transform::Pass InjectMbarrier() { +// auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { +// return WarpSpecializedRewriter::Substitute(f); +// }; +// return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); +// } + +// TVM_REGISTER_GLOBAL("tl.InjectMbarrier").set_body_typed(InjectMbarrier); + +// } // namespace tl +// } // 
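// Note: everything in inject_mbarrier.cc is commented out; with this patch the lowering in
// python/tvm/tl/engine.py instead goes through the new MultiVersionBuffer / WarpSpecialized /
// InjectSoftwarePipeline passes, so this file appears to be kept only as a reference stub.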
namespace tvm diff --git a/src/tl/transform/multi_version_buffer_rewriter.cc b/src/tl/transform/multi_version_buffer_rewriter.cc new file mode 100644 index 000000000000..64d484ff9f5e --- /dev/null +++ b/src/tl/transform/multi_version_buffer_rewriter.cc @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file warp_specialized_pipeline.cc + * \brief Warp specialized Pipeline for cuda GPU (sm90+) + */ + +#include +#include +#include +#include +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +enum class Role { kConsumer, kProducer, kBoth }; + +class WarpSpecializedRoleMarker_ : public StmtVisitor { + public: + WarpSpecializedRoleMarker_(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(buffer_data_to_buffer) {} + + Role GetRole(const StmtNode* stmt) const { + auto it = map_.find(stmt); + ICHECK(it != map_.end()); + return it->second; + } + + Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); } + + void VisitStmt_(const EvaluateNode* op) final { + Role role = Role::kConsumer; + if (auto call = op->value.as()) { + if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { + role = Role::kProducer; + has_bulk_copy_ = true; + } + } + SetRole(op, role); + } + + void VisitStmt_(const BufferStoreNode* op) final { + bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; + if (!is_shared_store) { + SetRole(op, Role::kConsumer); + return; + } + + // Check reads from global + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ GetRef(op)); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto reads = access[0]; + Role role = Role::kProducer; + for (auto read : reads) { + if (read->buffer.scope() != "global") { + role = Role::kConsumer; + break; + } + } + if (role == Role::kProducer) has_simt_copy_ = true; + SetRole(op, role); + } + + void VisitStmt_(const SeqStmtNode* op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->seq[0]); + for (auto stmt : op->seq) { + if (role != GetRole(stmt)) { + role = Role::kBoth; + break; + } + } + SetRole(op, role); + } + + void VisitStmt_(const IfThenElseNode* op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->then_case); + if (op->else_case.defined()) { + auto role_else = GetRole(op->else_case.value()); + if (role != role_else) role = Role::kBoth; + } + SetRole(op, role); + } + + void VisitStmt_(const BlockRealizeNode* op) final { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->block)); + } + + template + void HandleBodyStmt(const NodeType* op) { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->body)); + } + + void VisitStmt_(const ForNode* op) final { 
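// These VisitStmt_ overloads (For/Let/Attr/Assert/Block) simply propagate the role computed
// for their body: a statement is tagged kProducer when it only moves data from global into
// shared memory (TMA bulk copies or SIMT copies), kConsumer otherwise, and kBoth when a
// SeqStmt or IfThenElse mixes producer and consumer parts.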
HandleBodyStmt(op); } + void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); } + + bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } + + bool HasSimtCopy() { return has_simt_copy_; } + + private: + void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; } + Map buffer_data_to_buffer_; + std::unordered_map map_; + bool has_simt_copy_ = false; + bool has_bulk_copy_ = false; +}; + +class MultiVersionBufferRewriter : public StmtExprMutator { + public: + static PrimFunc Substitute(PrimFunc& f) { + auto rewriter = MultiVersionBufferRewriter(); + rewriter.buffer_lca_ = DetectBufferAccessLCA(f); + for (auto [buffer, _] : rewriter.buffer_lca_) { + Var buffer_var = buffer->data; + rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer); + } + f.CopyOnWrite()->body = rewriter(f->body); + return f; + } + + private: + MultiVersionBufferRewriter() = default; + + Array GetVersionedBuffers(Array seq_stmt, Array scoped_buffers) { + std::vector roles; + Array> reads, writes; + auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_); + for (auto stmt : seq_stmt) { + marker(stmt); + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt); + auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); + reads.push_back(std::move(access[0])); + writes.push_back(std::move(access[1])); + roles.push_back(marker.GetRole(stmt)); + } + + std::unordered_set consumer_used, producer_used; + for (size_t i = 0; i < seq_stmt.size(); i++) { + if (roles[i] == Role::kProducer) { + for (BufferRegion br : writes[i]) producer_used.insert(br->buffer.get()); + } else { + for (BufferRegion br : reads[i]) consumer_used.insert(br->buffer.get()); + } + } + Array versioned_buffers; + for (Buffer buffer : scoped_buffers) { + if (consumer_used.count(buffer.get()) && producer_used.count(buffer.get())) { + versioned_buffers.push_back(buffer); + } + } + return versioned_buffers; + } + + static Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { + ObjectPtr new_buffer = make_object(*(buffer.get())); + new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); + if (new_buffer->strides.size()) { + ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); + PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; + new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); + } + return Buffer(new_buffer); + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + BlockRealize block_realize = Downcast(StmtExprMutator::VisitStmt_(op)); + Block block = block_realize->block; + Array alloc_buffers; + for (auto buffer : block->alloc_buffers) { + if (buffer_remap_.count(buffer)) { + Buffer new_buffer = buffer_remap_[buffer]; + alloc_buffers.push_back(new_buffer); + } else { + alloc_buffers.push_back(buffer); + } + } + block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); + block_realize.CopyOnWrite()->block = block; + return block_realize; + } + + Stmt VisitStmt_(const ForNode* op) final { + auto num_stages_anno = op->annotations.Get("num_stages"); + if (!num_stages_anno.defined()) return StmtExprMutator::VisitStmt_(op); + + ICHECK(num_stages_anno.as()); + int num_stages = static_cast(num_stages_anno.as()->value); + + const SeqStmtNode* pipeline_body_seq = op->body.as(); + 
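// For a loop carrying a "num_stages" annotation, this visitor collects the buffers whose
// access scope (LCA) is this loop and that are written by producer statements and read by
// consumer statements, prepends a version dimension of size num_stages to each of them, and
// indexes that dimension with (loop_var - min) % num_stages so successive pipeline
// iterations use different buffer versions.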
CHECK(pipeline_body_seq) + << "ValueError: The body of the software pipeline should be SeqStmt, got " + << op->body->GetTypeKey(); + + Array scoped_buffers = {}; + for (auto [buffer, stmt] : buffer_lca_) { + if (stmt.defined() && stmt.value().get() == op) scoped_buffers.push_back(buffer); + } + + Array versioned_buffers = GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers); + + for (auto buffer : versioned_buffers) { + Var buffer_var = buffer->data; + Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages); + buffer_remap_.Set(buffer, new_buffer); + } + version_index_ = FloorMod(op->loop_var - op->min, num_stages); + auto for_node = StmtExprMutator::VisitStmt_(op); + + return for_node; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_remap_.find(load->buffer); + if (it == buffer_remap_.end()) { + return std::move(load); + } + const Buffer& new_buffer = (*it).second; + auto* n = load.CopyOnWrite(); + n->buffer = new_buffer; + n->indices.insert(n->indices.begin(), version_index_); + return std::move(load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_remap_.find(store->buffer); + if (it == buffer_remap_.end()) { + return std::move(store); + } + const Buffer& new_buffer = (*it).second; + auto* n = store.CopyOnWrite(); + n->buffer = new_buffer; + n->indices.insert(n->indices.begin(), version_index_); + return std::move(store); + } + + PrimExpr VisitExpr_(const CallNode* op) final { + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(builtin::tvm_access_ptr())) { + return RewriteBufferAccess(call, {1}); + } + return call; + } + + PrimExpr RewriteBufferAccess(const Call& call, const std::vector arg_indices) { + auto product = [](const Array& input) { + return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), input); + }; + Array new_args = call->args; + for (int i : arg_indices) { + auto buffer_var = Downcast(call->args[i]); + if (!buffer_data_to_buffer_.count(buffer_var)) continue; + const Buffer& buffer = buffer_data_to_buffer_[buffer_var]; + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + const Buffer& new_buffer = (*it).second; + const PrimExpr& old_index = call->args[i + 1]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = product(buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = old_index + version_index_ * offset; + new_args.Set(i + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } + + PrimExpr version_index_; + Map buffer_data_to_buffer_; + Map> buffer_lca_; + Map buffer_remap_; +}; + +using namespace tir::transform; + +tvm::transform::Pass MultiVersionBuffer() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return MultiVersionBufferRewriter::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {}); +} + +TVM_REGISTER_GLOBAL("tl.MultiVersionBuffer").set_body_typed(MultiVersionBuffer); + +} // namespace tl +} // namespace tvm diff --git a/src/tl/transform/warp_specialized_pipeline.cc b/src/tl/transform/warp_specialized_pipeline.cc index b016fe8a62d5..b4561a4cf14d 100644 --- a/src/tl/transform/warp_specialized_pipeline.cc +++ b/src/tl/transform/warp_specialized_pipeline.cc @@ -1,849 +1,849 @@ -/* - * Licensed to the Apache 
Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file warp_specialized_pipeline.cc - * \brief Warp specialized Pipeline for cuda GPU (sm90+) - */ - -#include -#include -#include -#include -#include - -#include "../op/builtin.h" - -namespace tvm { -namespace tl { - -using namespace tir; - -enum class Role { kConsumer, kProducer, kBoth }; - -class WarpSpecializedRoleMarker : public StmtVisitor { - public: - WarpSpecializedRoleMarker(Map buffer_data_to_buffer) - : buffer_data_to_buffer_(buffer_data_to_buffer) {} - - Role GetRole(const StmtNode* stmt) const { - auto it = map_.find(stmt); - ICHECK(it != map_.end()); - return it->second; - } - - Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); } - - void VisitStmt_(const EvaluateNode* op) final { - Role role = Role::kConsumer; - if (auto call = op->value.as()) { - if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { - role = Role::kProducer; - has_bulk_copy_ = true; - } - } - SetRole(op, role); - } - - void VisitStmt_(const BufferStoreNode* op) final { - bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; - if (!is_shared_store) { - SetRole(op, Role::kConsumer); - return; - } - - // Check reads from global - Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", - /*body*/ GetRef(op)); - auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); - auto reads = access[0]; - Role role = Role::kProducer; - for (auto read : reads) { - if (read->buffer.scope() != "global") { - role = Role::kConsumer; - break; - } - } - if (role == Role::kProducer) has_simt_copy_ = true; - SetRole(op, role); - } - - void VisitStmt_(const SeqStmtNode* op) final { - StmtVisitor::VisitStmt_(op); - auto role = GetRole(op->seq[0]); - for (auto stmt : op->seq) { - if (role != GetRole(stmt)) { - role = Role::kBoth; - break; - } - } - SetRole(op, role); - } - - void VisitStmt_(const IfThenElseNode* op) final { - StmtVisitor::VisitStmt_(op); - auto role = GetRole(op->then_case); - if (op->else_case.defined()) { - auto role_else = GetRole(op->else_case.value()); - if (role != role_else) role = Role::kBoth; - } - SetRole(op, role); - } - - void VisitStmt_(const BlockRealizeNode* op) final { - StmtVisitor::VisitStmt_(op); - SetRole(op, GetRole(op->block)); - } - - template - void HandleBodyStmt(const NodeType* op) { - StmtVisitor::VisitStmt_(op); - SetRole(op, GetRole(op->body)); - } - - void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); } - void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); } - void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); } - void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); } - void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); 
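As a quick illustration of the role-merging rule used by WarpSpecializedRoleMarker (a compound statement keeps the common role of its children and degrades to kBoth on any disagreement), here is a minimal standalone model in plain C++. It is only an illustration of the rule, not the visitor itself.

#include <cassert>
#include <initializer_list>

enum class Role { kConsumer, kProducer, kBoth };

// A SeqStmt or IfThenElse takes the common role of its children and falls
// back to kBoth as soon as any child disagrees (call with a non-empty list).
Role MergeRoles(std::initializer_list<Role> children) {
  Role merged = *children.begin();
  for (Role r : children)
    if (r != merged) return Role::kBoth;
  return merged;
}

int main() {
  assert(MergeRoles({Role::kProducer, Role::kProducer}) == Role::kProducer);
  assert(MergeRoles({Role::kProducer, Role::kConsumer}) == Role::kBoth);
  assert(MergeRoles({Role::kConsumer, Role::kConsumer, Role::kConsumer}) == Role::kConsumer);
}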
} - - bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } - - bool HasSimtCopy() { return has_simt_copy_; } - - private: - void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; } - Map buffer_data_to_buffer_; - std::unordered_map map_; - bool has_simt_copy_ = false; - bool has_bulk_copy_ = false; -}; - -static PrimExpr makeGetBarrier(PrimExpr barrier_id) { - return Call(DataType::Handle(), GetMBarrierOp(), {barrier_id}); -} - -static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) { - auto call = Call(DataType::Handle(), MBarrierExpectTX(), {makeGetBarrier(barrier_id), bytes}); - return Evaluate(call); -} - -static Stmt makeArriveBarrier(PrimExpr barrier_id) { - auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(), {makeGetBarrier(barrier_id)}); - return Evaluate(call); -} - -static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) { - auto call = - Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), {makeGetBarrier(barrier_id)}); - return Evaluate(call); -} - -static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { - auto call = Call(DataType::Handle(), MBarrierWaitParity(), {makeGetBarrier(barrier_id), parity}); - return Evaluate(call); -} - -class ProducerTraitsCollector : public StmtExprVisitor { - public: - ProducerTraitsCollector() { Clear(); } - - void Clear() { - bulk_copy_bytes = 0; - loop_extents = 1; - has_simt_copy = false; - } - - void Collect(Stmt stmt) { VisitStmt(stmt); } - - bool HasSimtCopy() { return has_simt_copy; } - - PrimExpr BulkCopyBytes() { return bulk_copy_bytes; } - - private: - void VisitExpr_(const CallNode* call) final { - if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { - Call access_ptr = Downcast(call->args[2]); - ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); - int type_bytes = access_ptr->args[0]->dtype.bytes(); - bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes; - } - StmtExprVisitor::VisitExpr_(call); - } - - void VisitStmt_(const ForNode* op) final { - PrimExpr old_loop_evtents = loop_extents; - loop_extents *= op->extent; - StmtExprVisitor::VisitStmt_(op); - loop_extents = old_loop_evtents; - } - - void VisitExpr_(const BufferLoadNode* op) final { - has_simt_copy = true; - StmtExprVisitor::VisitExpr_(op); - } - - bool has_simt_copy; - PrimExpr bulk_copy_bytes; - PrimExpr loop_extents; -}; - -// Rewrite the producer Stmt to use the correct barrier index -class MbarrierRewriter : public StmtExprMutator { - public: - static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) { - MbarrierRewriter rewriter; - rewriter.producer_barrier_idx_ = barrier_id; - return rewriter(stmt); - } - - private: - PrimExpr VisitExpr_(const CallNode* op) final { - auto call = Downcast(StmtExprMutator::VisitExpr_(op)); - if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { - Call access_ptr = Downcast(call->args[2]); - ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); - call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_)); - } - return call; - } - PrimExpr producer_barrier_idx_; -}; - -class MultiVersionBufferRewriter : public StmtExprMutator { - public: - static PrimFunc Rewrite(PrimFunc& f) { - auto rewriter = MultiVersionBufferRewriter(); - rewriter.buffer_lca_ = DetectBufferAccessLCA(f); - for (auto [buffer, _] : rewriter.buffer_lca_) { - Var buffer_var = buffer->data; - rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer); - } - f.CopyOnWrite()->body = rewriter(f->body); - return f; - } - - private: - 
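A condensed model of the byte accounting that ProducerTraitsCollector performs for the mbarrier expect-tx: every TMA load contributes element-count times dtype-bytes, scaled by the extents of the loops enclosing it. The tile sizes and the fp16 dtype below are assumed purely for illustration.

#include <cstdio>
#include <vector>

struct TmaLoad { int elems; int dtype_bytes; };

// Bytes expected at the barrier: each load contributes elems * dtype_bytes,
// multiplied by the combined extent of the loops wrapped around it.
long long ExpectedTxBytes(const std::vector<TmaLoad>& loads_in_body,
                          const std::vector<int>& enclosing_loop_extents) {
  long long loop_factor = 1;
  for (int e : enclosing_loop_extents) loop_factor *= e;
  long long bytes = 0;
  for (const TmaLoad& ld : loads_in_body)
    bytes += 1LL * ld.elems * ld.dtype_bytes * loop_factor;
  return bytes;
}

int main() {
  // e.g. two fp16 tile loads of 64x64 elements issued inside a loop of extent 2
  std::vector<TmaLoad> loads = {{64 * 64, 2}, {64 * 64, 2}};
  std::printf("expect-tx bytes = %lld\n", ExpectedTxBytes(loads, {2}));  // 32768
}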
MultiVersionBufferRewriter() = default; - - Array GetVersionedBuffers(Array seq_stmt, Array scoped_buffers) { - std::vector roles; - Array> reads, writes; - auto marker = WarpSpecializedRoleMarker(buffer_data_to_buffer_); - for (auto stmt : seq_stmt) { - marker(stmt); - Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt); - auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); - reads.push_back(std::move(access[0])); - writes.push_back(std::move(access[1])); - roles.push_back(marker.GetRole(stmt)); - } - - std::unordered_set consumer_used, producer_used; - for (size_t i = 0; i < seq_stmt.size(); i++) { - if (roles[i] == Role::kProducer) { - for (BufferRegion br : writes[i]) producer_used.insert(br->buffer.get()); - } else { - for (BufferRegion br : reads[i]) consumer_used.insert(br->buffer.get()); - } - } - Array versioned_buffers; - for (Buffer buffer : scoped_buffers) { - if (consumer_used.count(buffer.get()) && producer_used.count(buffer.get())) { - versioned_buffers.push_back(buffer); - } - } - return versioned_buffers; - } - - static Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { - ObjectPtr new_buffer = make_object(*(buffer.get())); - new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); - if (new_buffer->strides.size()) { - ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); - PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; - new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); - } - return Buffer(new_buffer); - } - - Stmt VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize block_realize = Downcast(StmtExprMutator::VisitStmt_(op)); - Block block = block_realize->block; - Array alloc_buffers; - for (auto buffer : block->alloc_buffers) { - if (buffer_remap_.count(buffer)) { - Buffer new_buffer = buffer_remap_[buffer]; - alloc_buffers.push_back(new_buffer); - } else { - alloc_buffers.push_back(buffer); - } - } - block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); - block_realize.CopyOnWrite()->block = block; - return block_realize; - } - - Stmt VisitStmt_(const ForNode* op) final { - auto num_stages_anno = op->annotations.Get("num_stages"); - if (!num_stages_anno.defined()) return StmtExprMutator::VisitStmt_(op); - - ICHECK(num_stages_anno.as()); - int num_stages = static_cast(num_stages_anno.as()->value); - - const SeqStmtNode* pipeline_body_seq = op->body.as(); - CHECK(pipeline_body_seq) - << "ValueError: The body of the software pipeline should be SeqStmt, got " - << op->body->GetTypeKey(); - - Array scoped_buffers = {}; - for (auto [buffer, stmt] : buffer_lca_) { - if (stmt.defined() && stmt.value().get() == op) scoped_buffers.push_back(buffer); - } - - Array versioned_buffers = GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers); - - for (auto buffer : versioned_buffers) { - Var buffer_var = buffer->data; - Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages); - buffer_remap_.Set(buffer, new_buffer); - } - version_index_ = FloorMod(op->loop_var - op->min, num_stages); - auto for_node = StmtExprMutator::VisitStmt_(op); - - return for_node; - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); - auto it = buffer_remap_.find(load->buffer); - if (it == buffer_remap_.end()) { - return std::move(load); - } - const Buffer& new_buffer = (*it).second; - auto* n = load.CopyOnWrite(); - n->buffer = new_buffer; - 
n->indices.insert(n->indices.begin(), version_index_); - return std::move(load); - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); - auto it = buffer_remap_.find(store->buffer); - if (it == buffer_remap_.end()) { - return std::move(store); - } - const Buffer& new_buffer = (*it).second; - auto* n = store.CopyOnWrite(); - n->buffer = new_buffer; - n->indices.insert(n->indices.begin(), version_index_); - return std::move(store); - } - - PrimExpr VisitExpr_(const CallNode* op) final { - Call call = Downcast(StmtExprMutator::VisitExpr_(op)); - if (call->op.same_as(builtin::tvm_access_ptr())) { - return RewriteBufferAccess(call, {1}); - } - return call; - } - - PrimExpr RewriteBufferAccess(const Call& call, const std::vector arg_indices) { - auto product = [](const Array& input) { - return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), input); - }; - Array new_args = call->args; - for (int i : arg_indices) { - auto buffer_var = Downcast(call->args[i]); - if (!buffer_data_to_buffer_.count(buffer_var)) continue; - const Buffer& buffer = buffer_data_to_buffer_[buffer_var]; - auto it = buffer_remap_.find(buffer); - if (it != buffer_remap_.end()) { - const Buffer& new_buffer = (*it).second; - const PrimExpr& old_index = call->args[i + 1]; - PrimExpr offset; - if (new_buffer->strides.empty()) { - offset = product(buffer->shape); - } else { - offset = new_buffer->strides[0]; - } - PrimExpr new_index = old_index + version_index_ * offset; - new_args.Set(i + 1, new_index); - } - } - return Call(call->dtype, call->op, new_args, call->span); - } - - PrimExpr version_index_; - Map buffer_data_to_buffer_; - Map> buffer_lca_; - Map buffer_remap_; -}; - -class ThreadIdxRewriter : public StmtExprMutator { - public: - static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) { - auto rewriter = ThreadIdxRewriter(thread_var, replaced); - return rewriter(stmt); - } - - private: - ThreadIdxRewriter(Var thread_var, PrimExpr replaced) - : thread_var_(thread_var), replaced_(replaced) {} - - PrimExpr VisitExpr_(const VarNode* var) final { - if (var == thread_var_.get()) { - return replaced_; - } else { - return StmtExprMutator::VisitExpr_(var); - } - } - - Var thread_var_; - PrimExpr replaced_; -}; - -class WSCodeEmitter : public StmtMutator { - public: - WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv, - Map buffer_data_to_buffer, const WarpSpecializedRoleMarker& marker) - : is_emitting_producer_(is_emitting_producer), - buffer_data_to_buffer_(buffer_data_to_buffer), - marker_(marker), - thread_var_(thread_iv->var) {} - - private: - template - Stmt FilterByRole(const NodeType* op) { - Role role = marker_.GetRole(op); - if (role == Role::kBoth) - return StmtMutator::VisitStmt_(op); - else if ((role == Role::kProducer) == is_emitting_producer_) - return GetRef(op); - else - return Evaluate(0); - } - - Stmt VisitStmt_(const SeqStmtNode* op) final { - bool has_producer = false; - for (auto stmt : op->seq) { - if (marker_.GetRole(stmt) == Role::kProducer) { - has_producer = true; - break; - } - } - bool need_producer_sync = has_producer && marker_.GetRole(op) == Role::kBoth; - if (!need_producer_sync) return FilterByRole(op); - - auto seq_transformed = op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); }); - - auto map = ExtractSyncPattern(op->seq); - Array new_body; - - if (is_emitting_producer_) { // producer case - ProducerTraitsCollector collector; - for (int i = 
0; i < static_cast(op->seq.size()); i++) { - if (marker_.GetRole(op->seq[i]) == Role::kConsumer) continue; - if (marker_.GetRole(op->seq[i]) == Role::kBoth) { - new_body.push_back(seq_transformed[i]); - continue; - } - if (map.acquire[i] != -1) { - PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i]; - PrimExpr parity = - map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_; - new_body.push_back(makeParityWait(acquire_barrier_id, parity)); - } - ICHECK(map.release[i] >= 0); - PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; - auto stmt = MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id); - collector.Collect(stmt); - if (!is_zero(collector.BulkCopyBytes())) { - auto expect_tx = IfThenElse(EQ(thread_var_, 0), - makeExpectTX(release_barrier_id, collector.BulkCopyBytes())); - new_body.push_back(expect_tx); - } - new_body.push_back(stmt); - if (collector.HasSimtCopy() > 0) { - new_body.push_back(makeCpAsyncBarrier(release_barrier_id)); - } - if (map.release_after[i]) { - new_body.push_back(makeArriveBarrier(release_barrier_id)); - for (int j = 0; j < num_stages_; j++) { - released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); - } - } - collector.Clear(); - } - } else { // consumer case - for (int i = 0; i < static_cast(op->seq.size()); i++) { - if (marker_.GetRole(op->seq[i]) == Role::kProducer) continue; - if (map.acquire[i] != -1) { - PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i]; - PrimExpr parity = - map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_; - new_body.push_back(makeParityWait(acquire_barrier_id, parity)); - } - new_body.push_back(seq_transformed[i]); - if (map.release_after[i]) { - PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; - new_body.push_back(makeArriveBarrier(release_barrier_id)); - for (int j = 0; j < num_stages_; j++) { - released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); - } - } - } - } - - num_barriers_ += map.patterns.size() * num_stages_; - - ICHECK(new_body.size() > 0); - return new_body.size() == 1 ? 
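Two pieces of arithmetic drive the barrier handshake emitted here: the physical barrier chosen for sync pattern p at pipeline stage s is s + num_barriers + num_stages * p, and the wait parity follows the per-iteration recurrence maintained by the For visitor. A small standalone sketch with illustrative values (2 stages, 6 iterations, one sync pattern):

#include <cstdio>

// Barrier index used by pattern `p` at pipeline stage `s`, given that `base`
// barriers were already allocated by earlier SeqStmts.
int BarrierId(int base, int num_stages, int p, int s) { return s + base + num_stages * p; }

// Parity advances each time the loop has cycled through all stages, mirroring
//   parity = (outer_parity * extent + (i - min) / num_stages) % 2.
int Parity(int outer_parity, int extent, int i, int loop_min, int num_stages) {
  return (outer_parity * extent + (i - loop_min) / num_stages) % 2;
}

int main() {
  const int base = 0, num_stages = 2;
  for (int i = 0; i < 6; i++) {
    int stage = i % num_stages;
    std::printf("iter %d: pattern 0 waits on barrier %d with parity %d\n", i,
                BarrierId(base, num_stages, /*p=*/0, stage),
                Parity(/*outer=*/0, /*extent=*/6, i, /*min=*/0, num_stages));
  }
}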
new_body[0] : SeqStmt(std::move(new_body)); - } - - Stmt VisitStmt_(const ForNode* op) final { - int num_stages = 1; - auto num_stages_anno = op->annotations.Get("num_stages"); - if (num_stages_anno.defined()) { - ICHECK(num_stages_anno.as()); - num_stages = static_cast(num_stages_anno.as()->value); - ICHECK(num_stages_ == 1) << "Nested pipeline not supported."; - } - - PrimExpr parity_before = std::move(parity_); - PrimExpr stage_before = std::move(stage_); - int num_stages_before = num_stages_; - - num_stages_ = num_stages; - stage_ = FloorMod(op->loop_var - op->min, num_stages); - parity_ = - FloorMod(parity_before * op->extent + FloorDiv(op->loop_var - op->min, num_stages), 2); - - auto result = FilterByRole(op); - - parity_ = std::move(parity_before); - stage_ = std::move(stage_before); - num_stages_ = num_stages_before; - - // remove pipeline annotation - auto for_node = result.as(); - if (result.as()) { - auto for_node = Downcast(result); - for_node.CopyOnWrite()->annotations.erase("num_stages"); - return for_node; - } - return result; - } - - Stmt VisitStmt_(const IfThenElseNode* op) final { return FilterByRole(op); } - Stmt VisitStmt_(const EvaluateNode* op) final { return FilterByRole(op); } - Stmt VisitStmt_(const AttrStmtNode* op) final { return FilterByRole(op); } - Stmt VisitStmt_(const BufferStoreNode* op) final { return FilterByRole(op); } - Stmt VisitStmt_(const LetStmtNode* op) final { return FilterByRole(op); } - Stmt VisitStmt_(const AssertStmtNode* op) final { return FilterByRole(op); } - Stmt VisitStmt_(const BlockNode* op) final { - ICHECK(0); - return Stmt(); - } - Stmt VisitStmt_(const BlockRealizeNode* op) final { - ICHECK(0); - return Stmt(); - } - - struct SyncPattern { - int release_idx, acquire_idx; - }; - - struct SyncPatternMap { - std::vector acquire; - std::vector release; - std::vector release_after; - std::vector patterns; - bool is_loop_dependency(int i) { - // return if the acquire is based on release in the previous iteration - return patterns[i].release_idx > patterns[i].acquire_idx; - } - }; - - std::vector CreateBaseSyncPairs(Array seq_stmt, - const std::vector& is_producer) { - const int n = seq_stmt.size(); - std::vector> reads, writes; - reads.reserve(n); - writes.reserve(n); - for (int i = 0; i < n; i++) { - Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", - /*body*/ seq_stmt[i]); - auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); - std::set read_set, write_set; - for (auto region : access[0]) read_set.insert(region->buffer.get()); - for (auto region : access[1]) write_set.insert(region->buffer.get()); - reads.push_back(std::move(read_set)); - writes.push_back(std::move(write_set)); - } - - auto intersect_fn = [](const std::set& lhs, - const std::set& rhs) { - for (auto ptr : lhs) - if (rhs.count(ptr)) return true; - return false; - }; - - std::vector sync_patterns; - // producer_release consumer_acquire, - // inject before the first consumer stmt for each producer - for (int i = 0; i < n; i++) { - for (int j = i + 1; j < n; j++) { - if (is_producer[i] != is_producer[j] && - (intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) { - sync_patterns.push_back({i, j}); - break; - } - } - } - - // consumer_release producer_acquire - // valid when is_loop is true - // inject before the earlest producer stmt for each consumer - bool in_loop = !is_zero(parity_); - if (in_loop) { - for (int i = 0; i < n; i++) { - for (int j = 0; j < i; j++) { - if (is_producer[i] != is_producer[j] && - 
(intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) { - sync_patterns.push_back({i, j}); - break; - } - } - } - } - - return sync_patterns; - } - - static std::vector RemoveUnusedSyncPatterns( - const std::vector& sync_patterns, const std::vector& is_producer) { - /* - Simplify multiple release-acquire pairs into one - ------------------ - Produce(A) - Produce(B) - Consume(A, B) - ------------------ - [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)] - - Or - ------------------ - Produce(A, B) - Consume(A) - Consume(B) - ------------------ - [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)] - */ - int M = sync_patterns.size(); - std::vector removed(M, false); - for (int i = 0; i < M; i++) { - for (int j = 0; j < M; j++) { - if (is_producer[sync_patterns[i].acquire_idx] == - is_producer[sync_patterns[j].acquire_idx] && - sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx && - sync_patterns[i].release_idx < sync_patterns[j].release_idx) - removed[i] = true; - } - } - - std::vector sync_pattern_cleaned; - sync_pattern_cleaned.reserve(M); - for (int i = 0; i < M; i++) - if (!removed[i]) sync_pattern_cleaned.push_back(sync_patterns[i]); - - return sync_pattern_cleaned; - } - - SyncPatternMap ExtractSyncPattern(Array seq_stmt) { - size_t num_stmts = seq_stmt.size(); - std::vector is_producer; - is_producer.reserve(num_stmts); - for (auto stmt : seq_stmt) { - is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer); - } - - auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer); - auto sync_patterns = RemoveUnusedSyncPatterns(sync_patterns_base, is_producer); - - // for (auto pattern : sync_patterns) { - // std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl; - // } - - SyncPatternMap map; - map.patterns = sync_patterns; - map.acquire.resize(num_stmts, -1); - map.release.resize(num_stmts, -1); - map.release_after.resize(num_stmts, false); - for (size_t i = 0; i < sync_patterns.size(); i++) { - map.acquire[sync_patterns[i].acquire_idx] = i; - map.release[sync_patterns[i].release_idx] = i; - map.release_after[sync_patterns[i].release_idx] = true; - } - - int cur_consumer_barrier = -1, cur_producer_barrier = -1; - for (int i = num_stmts - 1; i >= 0; i--) { - if (is_producer[i]) { - if (map.release[i] == -1) { - map.release[i] = cur_producer_barrier; - } else { - cur_producer_barrier = map.release[i]; - } - } else { - if (map.release[i] == -1) { - map.release[i] = cur_consumer_barrier; - } else { - cur_consumer_barrier = map.release[i]; - } - } - } - return map; - } - - const bool is_emitting_producer_; - Map buffer_data_to_buffer_; - std::unordered_set released_barrier_; - const WarpSpecializedRoleMarker& marker_; - - int num_barriers_ = 0; - PrimExpr parity_ = 0; - PrimExpr stage_ = 0; - int num_stages_ = 1; - Var thread_var_; - friend class WarpSpecializedPipeline; -}; - -class WarpSpecializedPipeline : public StmtExprMutator { - public: - static PrimFunc Substitute(PrimFunc f) { - f = MultiVersionBufferRewriter::Rewrite(f); - auto T = WarpSpecializedPipeline(); - T.buffer_lca_ = DetectBufferAccessLCA(f); - for (auto [buffer, _] : T.buffer_lca_) T.buffer_data_to_buffer_.Set(buffer->data, buffer); - f.CopyOnWrite()->body = T(f->body); - return f; - } - - private: - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent && - Downcast(op->node)->thread_tag == "threadIdx.x") { - thread_iv_ = Downcast(op->node); - AttrStmt attr_stmt = Downcast(StmtExprMutator::VisitStmt_(op)); - if 
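The pairing step can be summarized as: walk the statement list and, for each statement, record a release-to-acquire edge to the first later statement of the opposite role that touches a common buffer; the reverse, loop-carried edges are added separately once the sequence sits inside a pipelined loop. Below is a standalone sketch of just the forward direction, over hypothetical buffer ids, reproducing the Produce(A); Produce(B); Consume(A, B) example from the comment above.

#include <cstdio>
#include <set>
#include <utility>
#include <vector>

struct Stmt { bool is_producer; std::set<int> reads, writes; };

static bool Intersects(const std::set<int>& a, const std::set<int>& b) {
  for (int x : a) if (b.count(x)) return true;
  return false;
}

// For each statement, find the first later statement of the opposite role
// that conflicts with it; that pair becomes release(i) -> acquire(j).
std::vector<std::pair<int, int>> BaseSyncPairs(const std::vector<Stmt>& seq) {
  std::vector<std::pair<int, int>> pairs;
  for (size_t i = 0; i < seq.size(); i++)
    for (size_t j = i + 1; j < seq.size(); j++)
      if (seq[i].is_producer != seq[j].is_producer &&
          (Intersects(seq[i].writes, seq[j].reads) || Intersects(seq[i].reads, seq[j].writes))) {
        pairs.push_back({static_cast<int>(i), static_cast<int>(j)});
        break;
      }
  return pairs;
}

int main() {
  // Produce(A); Produce(B); Consume(A, B)  ->  pairs (0, 2) and (1, 2)
  std::vector<Stmt> seq = {{true, {}, {0}}, {true, {}, {1}}, {false, {0, 1}, {}}};
  for (auto [r, a] : BaseSyncPairs(seq)) std::printf("release %d -> acquire %d\n", r, a);
}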
(updated_thread_extent_.defined()) { - thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()}; - attr_stmt.CopyOnWrite()->node = thread_iv_; - attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value(); - } - thread_iv_ = {}; - return attr_stmt; - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - Stmt VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize block_realize = Downcast(StmtExprMutator::VisitStmt_(op)); - if (!thread_iv_.defined()) { - return block_realize; - } - ICHECK(!updated_thread_extent_.defined()); - - Block block = block_realize->block; - WarpSpecializedRoleMarker marker(buffer_data_to_buffer_); - marker(block); - if (!marker.HasProducer()) { - // Cannot detect any producer here, directly return. - return block_realize; - } - - WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker); - WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker); - Stmt producer_code = producer(block->body); - Stmt consumer_code = consumer(block->body); - - PrimExpr consumer_thread_extent = thread_iv_->dom->extent; - PrimExpr producer_thread_extent = thread_iv_->dom->extent; - // Need one warp-group for bulk-copy only case - if (!marker.HasSimtCopy()) producer_thread_extent = 128; - - // TODO: estimate the correct reg usage. - auto inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1})); - auto dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0})); - - producer_code = SeqStmt({dec_reg_stmt, producer_code}); - consumer_code = SeqStmt({inc_reg_stmt, consumer_code}); - - producer_code = ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var, - thread_iv_->var - consumer_thread_extent); - updated_thread_extent_ = consumer_thread_extent + producer_thread_extent; - - ICHECK(producer.num_barriers_ == consumer.num_barriers_) - << producer.num_barriers_ << " " << consumer.num_barriers_; - int num_barriers = consumer.num_barriers_; - Array barrier_num_threads; - barrier_num_threads.reserve(num_barriers); - for (int i = 0; i < num_barriers; i++) { - PrimExpr arrive_thread_count = - producer.released_barrier_.count(i) ? producer_thread_extent : consumer_thread_extent; - barrier_num_threads.push_back(arrive_thread_count); - } - - Stmt init_barrier = - Evaluate(Call(DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads)); - Stmt body = - IfThenElse(GE(thread_iv_->var, consumer_thread_extent), producer_code, consumer_code); - // Add an attr here to handle the partial thread count in THreadSync pass. 
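The net effect of this rewrite is a launch with consumer_extent + 128 threads: threads past the original extent take the producer branch, see threadIdx.x remapped to tid - consumer_extent, and drop to 24 registers, while the consumer side raises its budget to 240. A plain C++ model of that partition follows; the 256-thread consumer extent is a made-up example, and 128/24/240 mirror the constants hard-coded above.

#include <cstdio>

struct LaunchShape {
  int consumer_threads;   // original threadIdx.x extent
  int producer_threads;   // 128 = one warp-group when only bulk copies are issued
  int total() const { return consumer_threads + producer_threads; }
};

int main() {
  LaunchShape shape{256, 128};
  for (int tid : {0, 255, 256, 383}) {
    bool is_producer = tid >= shape.consumer_threads;                 // IfThenElse(GE(tid, consumer_extent), ...)
    int local_tid = is_producer ? tid - shape.consumer_threads : tid; // ThreadIdxRewriter
    int maxnreg   = is_producer ? 24 : 240;                           // SetMaxNReg(24, 0) / SetMaxNReg(240, 1)
    std::printf("tid %3d -> %s, local tid %3d, maxnreg %d\n", tid,
                is_producer ? "producer" : "consumer", local_tid, maxnreg);
  }
}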
- Array ws_partition = {Downcast(producer_thread_extent), - Downcast(consumer_thread_extent)}; - body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body); - - block.CopyOnWrite()->body = SeqStmt({init_barrier, body}); - block_realize.CopyOnWrite()->block = block; - return block_realize; - } - - WarpSpecializedPipeline() = default; - - Map buffer_data_to_buffer_; - Map> buffer_lca_; - Map buffer_remap_; - IterVar thread_iv_; - Optional updated_thread_extent_; -}; - -using namespace tir::transform; - -tvm::transform::Pass WarpSpecializedPipeline() { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return WarpSpecializedPipeline::Substitute(f); - }; - return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecializedPipeline", {}); -} - -TVM_REGISTER_GLOBAL("tl.WarpSpecializedPipeline").set_body_typed(WarpSpecializedPipeline); - -} // namespace tl -} // namespace tvm +// /* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ + +// /*! +// * \file warp_specialized_pipeline.cc +// * \brief Warp specialized Pipeline for cuda GPU (sm90+) +// */ + +// #include +// #include +// #include +// #include +// #include + +// #include "../op/builtin.h" + +// namespace tvm { +// namespace tl { + +// using namespace tir; + +// enum class Role { kConsumer, kProducer, kBoth }; + +// class WarpSpecializedRoleMarker : public StmtVisitor { +// public: +// WarpSpecializedRoleMarker(Map buffer_data_to_buffer) +// : buffer_data_to_buffer_(buffer_data_to_buffer) {} + +// Role GetRole(const StmtNode* stmt) const { +// auto it = map_.find(stmt); +// ICHECK(it != map_.end()); +// return it->second; +// } + +// Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); } + +// void VisitStmt_(const EvaluateNode* op) final { +// Role role = Role::kConsumer; +// if (auto call = op->value.as()) { +// if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { +// role = Role::kProducer; +// has_bulk_copy_ = true; +// } +// } +// SetRole(op, role); +// } + +// void VisitStmt_(const BufferStoreNode* op) final { +// bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; +// if (!is_shared_store) { +// SetRole(op, Role::kConsumer); +// return; +// } + +// // Check reads from global +// Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", +// /*body*/ GetRef(op)); +// auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); +// auto reads = access[0]; +// Role role = Role::kProducer; +// for (auto read : reads) { +// if (read->buffer.scope() != "global") { +// role = Role::kConsumer; +// break; +// } +// } +// if (role == Role::kProducer) has_simt_copy_ = true; +// SetRole(op, role); +// } + +// void VisitStmt_(const 
SeqStmtNode* op) final { +// StmtVisitor::VisitStmt_(op); +// auto role = GetRole(op->seq[0]); +// for (auto stmt : op->seq) { +// if (role != GetRole(stmt)) { +// role = Role::kBoth; +// break; +// } +// } +// SetRole(op, role); +// } + +// void VisitStmt_(const IfThenElseNode* op) final { +// StmtVisitor::VisitStmt_(op); +// auto role = GetRole(op->then_case); +// if (op->else_case.defined()) { +// auto role_else = GetRole(op->else_case.value()); +// if (role != role_else) role = Role::kBoth; +// } +// SetRole(op, role); +// } + +// void VisitStmt_(const BlockRealizeNode* op) final { +// StmtVisitor::VisitStmt_(op); +// SetRole(op, GetRole(op->block)); +// } + +// template +// void HandleBodyStmt(const NodeType* op) { +// StmtVisitor::VisitStmt_(op); +// SetRole(op, GetRole(op->body)); +// } + +// void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); } +// void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); } +// void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); } +// void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); } +// void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); } + +// bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } + +// bool HasSimtCopy() { return has_simt_copy_; } + +// private: +// void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; } +// Map buffer_data_to_buffer_; +// std::unordered_map map_; +// bool has_simt_copy_ = false; +// bool has_bulk_copy_ = false; +// }; + +// static PrimExpr makeGetBarrier(PrimExpr barrier_id) { +// return Call(DataType::Handle(), GetMBarrierOp(), {barrier_id}); +// } + +// static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) { +// auto call = Call(DataType::Handle(), MBarrierExpectTX(), {makeGetBarrier(barrier_id), bytes}); +// return Evaluate(call); +// } + +// static Stmt makeArriveBarrier(PrimExpr barrier_id) { +// auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(), {makeGetBarrier(barrier_id)}); +// return Evaluate(call); +// } + +// static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) { +// auto call = +// Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), {makeGetBarrier(barrier_id)}); +// return Evaluate(call); +// } + +// static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { +// auto call = Call(DataType::Handle(), MBarrierWaitParity(), {makeGetBarrier(barrier_id), parity}); +// return Evaluate(call); +// } + +// class ProducerTraitsCollector : public StmtExprVisitor { +// public: +// ProducerTraitsCollector() { Clear(); } + +// void Clear() { +// bulk_copy_bytes = 0; +// loop_extents = 1; +// has_simt_copy = false; +// } + +// void Collect(Stmt stmt) { VisitStmt(stmt); } + +// bool HasSimtCopy() { return has_simt_copy; } + +// PrimExpr BulkCopyBytes() { return bulk_copy_bytes; } + +// private: +// void VisitExpr_(const CallNode* call) final { +// if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { +// Call access_ptr = Downcast(call->args[2]); +// ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); +// int type_bytes = access_ptr->args[0]->dtype.bytes(); +// bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes; +// } +// StmtExprVisitor::VisitExpr_(call); +// } + +// void VisitStmt_(const ForNode* op) final { +// PrimExpr old_loop_evtents = loop_extents; +// loop_extents *= op->extent; +// StmtExprVisitor::VisitStmt_(op); +// loop_extents = old_loop_evtents; +// } + +// void VisitExpr_(const BufferLoadNode* op) final { +// has_simt_copy = 
true; +// StmtExprVisitor::VisitExpr_(op); +// } + +// bool has_simt_copy; +// PrimExpr bulk_copy_bytes; +// PrimExpr loop_extents; +// }; + +// // Rewrite the producer Stmt to use the correct barrier index +// class MbarrierRewriter : public StmtExprMutator { +// public: +// static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) { +// MbarrierRewriter rewriter; +// rewriter.producer_barrier_idx_ = barrier_id; +// return rewriter(stmt); +// } + +// private: +// PrimExpr VisitExpr_(const CallNode* op) final { +// auto call = Downcast(StmtExprMutator::VisitExpr_(op)); +// if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { +// Call access_ptr = Downcast(call->args[2]); +// ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); +// call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_)); +// } +// return call; +// } +// PrimExpr producer_barrier_idx_; +// }; + +// class MultiVersionBufferRewriter : public StmtExprMutator { +// public: +// static PrimFunc Rewrite(PrimFunc& f) { +// auto rewriter = MultiVersionBufferRewriter(); +// rewriter.buffer_lca_ = DetectBufferAccessLCA(f); +// for (auto [buffer, _] : rewriter.buffer_lca_) { +// Var buffer_var = buffer->data; +// rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer); +// } +// f.CopyOnWrite()->body = rewriter(f->body); +// return f; +// } + +// private: +// MultiVersionBufferRewriter() = default; + +// Array GetVersionedBuffers(Array seq_stmt, Array scoped_buffers) { +// std::vector roles; +// Array> reads, writes; +// auto marker = WarpSpecializedRoleMarker(buffer_data_to_buffer_); +// for (auto stmt : seq_stmt) { +// marker(stmt); +// Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt); +// auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); +// reads.push_back(std::move(access[0])); +// writes.push_back(std::move(access[1])); +// roles.push_back(marker.GetRole(stmt)); +// } + +// std::unordered_set consumer_used, producer_used; +// for (size_t i = 0; i < seq_stmt.size(); i++) { +// if (roles[i] == Role::kProducer) { +// for (BufferRegion br : writes[i]) producer_used.insert(br->buffer.get()); +// } else { +// for (BufferRegion br : reads[i]) consumer_used.insert(br->buffer.get()); +// } +// } +// Array versioned_buffers; +// for (Buffer buffer : scoped_buffers) { +// if (consumer_used.count(buffer.get()) && producer_used.count(buffer.get())) { +// versioned_buffers.push_back(buffer); +// } +// } +// return versioned_buffers; +// } + +// static Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { +// ObjectPtr new_buffer = make_object(*(buffer.get())); +// new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); +// if (new_buffer->strides.size()) { +// ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); +// PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; +// new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); +// } +// return Buffer(new_buffer); +// } + +// Stmt VisitStmt_(const BlockRealizeNode* op) final { +// BlockRealize block_realize = Downcast(StmtExprMutator::VisitStmt_(op)); +// Block block = block_realize->block; +// Array alloc_buffers; +// for (auto buffer : block->alloc_buffers) { +// if (buffer_remap_.count(buffer)) { +// Buffer new_buffer = buffer_remap_[buffer]; +// alloc_buffers.push_back(new_buffer); +// } else { +// alloc_buffers.push_back(buffer); +// } +// } +// block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); +// 
block_realize.CopyOnWrite()->block = block; +// return block_realize; +// } + +// Stmt VisitStmt_(const ForNode* op) final { +// auto num_stages_anno = op->annotations.Get("num_stages"); +// if (!num_stages_anno.defined()) return StmtExprMutator::VisitStmt_(op); + +// ICHECK(num_stages_anno.as()); +// int num_stages = static_cast(num_stages_anno.as()->value); + +// const SeqStmtNode* pipeline_body_seq = op->body.as(); +// CHECK(pipeline_body_seq) +// << "ValueError: The body of the software pipeline should be SeqStmt, got " +// << op->body->GetTypeKey(); + +// Array scoped_buffers = {}; +// for (auto [buffer, stmt] : buffer_lca_) { +// if (stmt.defined() && stmt.value().get() == op) scoped_buffers.push_back(buffer); +// } + +// Array versioned_buffers = GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers); + +// for (auto buffer : versioned_buffers) { +// Var buffer_var = buffer->data; +// Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages); +// buffer_remap_.Set(buffer, new_buffer); +// } +// version_index_ = FloorMod(op->loop_var - op->min, num_stages); +// auto for_node = StmtExprMutator::VisitStmt_(op); + +// return for_node; +// } + +// PrimExpr VisitExpr_(const BufferLoadNode* op) final { +// BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); +// auto it = buffer_remap_.find(load->buffer); +// if (it == buffer_remap_.end()) { +// return std::move(load); +// } +// const Buffer& new_buffer = (*it).second; +// auto* n = load.CopyOnWrite(); +// n->buffer = new_buffer; +// n->indices.insert(n->indices.begin(), version_index_); +// return std::move(load); +// } + +// Stmt VisitStmt_(const BufferStoreNode* op) final { +// BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); +// auto it = buffer_remap_.find(store->buffer); +// if (it == buffer_remap_.end()) { +// return std::move(store); +// } +// const Buffer& new_buffer = (*it).second; +// auto* n = store.CopyOnWrite(); +// n->buffer = new_buffer; +// n->indices.insert(n->indices.begin(), version_index_); +// return std::move(store); +// } + +// PrimExpr VisitExpr_(const CallNode* op) final { +// Call call = Downcast(StmtExprMutator::VisitExpr_(op)); +// if (call->op.same_as(builtin::tvm_access_ptr())) { +// return RewriteBufferAccess(call, {1}); +// } +// return call; +// } + +// PrimExpr RewriteBufferAccess(const Call& call, const std::vector arg_indices) { +// auto product = [](const Array& input) { +// return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, +// make_const(DataType::Int(32), 1), input); +// }; +// Array new_args = call->args; +// for (int i : arg_indices) { +// auto buffer_var = Downcast(call->args[i]); +// if (!buffer_data_to_buffer_.count(buffer_var)) continue; +// const Buffer& buffer = buffer_data_to_buffer_[buffer_var]; +// auto it = buffer_remap_.find(buffer); +// if (it != buffer_remap_.end()) { +// const Buffer& new_buffer = (*it).second; +// const PrimExpr& old_index = call->args[i + 1]; +// PrimExpr offset; +// if (new_buffer->strides.empty()) { +// offset = product(buffer->shape); +// } else { +// offset = new_buffer->strides[0]; +// } +// PrimExpr new_index = old_index + version_index_ * offset; +// new_args.Set(i + 1, new_index); +// } +// } +// return Call(call->dtype, call->op, new_args, call->span); +// } + +// PrimExpr version_index_; +// Map buffer_data_to_buffer_; +// Map> buffer_lca_; +// Map buffer_remap_; +// }; + +// class ThreadIdxRewriter : public StmtExprMutator { +// public: +// static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr 
replaced) { +// auto rewriter = ThreadIdxRewriter(thread_var, replaced); +// return rewriter(stmt); +// } + +// private: +// ThreadIdxRewriter(Var thread_var, PrimExpr replaced) +// : thread_var_(thread_var), replaced_(replaced) {} + +// PrimExpr VisitExpr_(const VarNode* var) final { +// if (var == thread_var_.get()) { +// return replaced_; +// } else { +// return StmtExprMutator::VisitExpr_(var); +// } +// } + +// Var thread_var_; +// PrimExpr replaced_; +// }; + +// class WSCodeEmitter : public StmtMutator { +// public: +// WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv, +// Map buffer_data_to_buffer, const WarpSpecializedRoleMarker& marker) +// : is_emitting_producer_(is_emitting_producer), +// buffer_data_to_buffer_(buffer_data_to_buffer), +// marker_(marker), +// thread_var_(thread_iv->var) {} + +// private: +// template +// Stmt FilterByRole(const NodeType* op) { +// Role role = marker_.GetRole(op); +// if (role == Role::kBoth) +// return StmtMutator::VisitStmt_(op); +// else if ((role == Role::kProducer) == is_emitting_producer_) +// return GetRef(op); +// else +// return Evaluate(0); +// } + +// Stmt VisitStmt_(const SeqStmtNode* op) final { +// bool has_producer = false; +// for (auto stmt : op->seq) { +// if (marker_.GetRole(stmt) == Role::kProducer) { +// has_producer = true; +// break; +// } +// } +// bool need_producer_sync = has_producer && marker_.GetRole(op) == Role::kBoth; +// if (!need_producer_sync) return FilterByRole(op); + +// auto seq_transformed = op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); }); + +// auto map = ExtractSyncPattern(op->seq); +// Array new_body; + +// if (is_emitting_producer_) { // producer case +// ProducerTraitsCollector collector; +// for (int i = 0; i < static_cast(op->seq.size()); i++) { +// if (marker_.GetRole(op->seq[i]) == Role::kConsumer) continue; +// if (marker_.GetRole(op->seq[i]) == Role::kBoth) { +// new_body.push_back(seq_transformed[i]); +// continue; +// } +// if (map.acquire[i] != -1) { +// PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i]; +// PrimExpr parity = +// map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_; +// new_body.push_back(makeParityWait(acquire_barrier_id, parity)); +// } +// ICHECK(map.release[i] >= 0); +// PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; +// auto stmt = MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id); +// collector.Collect(stmt); +// if (!is_zero(collector.BulkCopyBytes())) { +// auto expect_tx = IfThenElse(EQ(thread_var_, 0), +// makeExpectTX(release_barrier_id, collector.BulkCopyBytes())); +// new_body.push_back(expect_tx); +// } +// new_body.push_back(stmt); +// if (collector.HasSimtCopy() > 0) { +// new_body.push_back(makeCpAsyncBarrier(release_barrier_id)); +// } +// if (map.release_after[i]) { +// new_body.push_back(makeArriveBarrier(release_barrier_id)); +// for (int j = 0; j < num_stages_; j++) { +// released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); +// } +// } +// collector.Clear(); +// } +// } else { // consumer case +// for (int i = 0; i < static_cast(op->seq.size()); i++) { +// if (marker_.GetRole(op->seq[i]) == Role::kProducer) continue; +// if (map.acquire[i] != -1) { +// PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i]; +// PrimExpr parity = +// map.is_loop_dependency(map.acquire[i]) ? 
bitwise_xor(parity_, 1) : parity_; +// new_body.push_back(makeParityWait(acquire_barrier_id, parity)); +// } +// new_body.push_back(seq_transformed[i]); +// if (map.release_after[i]) { +// PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; +// new_body.push_back(makeArriveBarrier(release_barrier_id)); +// for (int j = 0; j < num_stages_; j++) { +// released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); +// } +// } +// } +// } + +// num_barriers_ += map.patterns.size() * num_stages_; + +// ICHECK(new_body.size() > 0); +// return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)); +// } + +// Stmt VisitStmt_(const ForNode* op) final { +// int num_stages = 1; +// auto num_stages_anno = op->annotations.Get("num_stages"); +// if (num_stages_anno.defined()) { +// ICHECK(num_stages_anno.as()); +// num_stages = static_cast(num_stages_anno.as()->value); +// ICHECK(num_stages_ == 1) << "Nested pipeline not supported."; +// } + +// PrimExpr parity_before = std::move(parity_); +// PrimExpr stage_before = std::move(stage_); +// int num_stages_before = num_stages_; + +// num_stages_ = num_stages; +// stage_ = FloorMod(op->loop_var - op->min, num_stages); +// parity_ = +// FloorMod(parity_before * op->extent + FloorDiv(op->loop_var - op->min, num_stages), 2); + +// auto result = FilterByRole(op); + +// parity_ = std::move(parity_before); +// stage_ = std::move(stage_before); +// num_stages_ = num_stages_before; + +// // remove pipeline annotation +// auto for_node = result.as(); +// if (result.as()) { +// auto for_node = Downcast(result); +// for_node.CopyOnWrite()->annotations.erase("num_stages"); +// return for_node; +// } +// return result; +// } + +// Stmt VisitStmt_(const IfThenElseNode* op) final { return FilterByRole(op); } +// Stmt VisitStmt_(const EvaluateNode* op) final { return FilterByRole(op); } +// Stmt VisitStmt_(const AttrStmtNode* op) final { return FilterByRole(op); } +// Stmt VisitStmt_(const BufferStoreNode* op) final { return FilterByRole(op); } +// Stmt VisitStmt_(const LetStmtNode* op) final { return FilterByRole(op); } +// Stmt VisitStmt_(const AssertStmtNode* op) final { return FilterByRole(op); } +// Stmt VisitStmt_(const BlockNode* op) final { +// ICHECK(0); +// return Stmt(); +// } +// Stmt VisitStmt_(const BlockRealizeNode* op) final { +// ICHECK(0); +// return Stmt(); +// } + +// struct SyncPattern { +// int release_idx, acquire_idx; +// }; + +// struct SyncPatternMap { +// std::vector acquire; +// std::vector release; +// std::vector release_after; +// std::vector patterns; +// bool is_loop_dependency(int i) { +// // return if the acquire is based on release in the previous iteration +// return patterns[i].release_idx > patterns[i].acquire_idx; +// } +// }; + +// std::vector CreateBaseSyncPairs(Array seq_stmt, +// const std::vector& is_producer) { +// const int n = seq_stmt.size(); +// std::vector> reads, writes; +// reads.reserve(n); +// writes.reserve(n); +// for (int i = 0; i < n; i++) { +// Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", +// /*body*/ seq_stmt[i]); +// auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); +// std::set read_set, write_set; +// for (auto region : access[0]) read_set.insert(region->buffer.get()); +// for (auto region : access[1]) write_set.insert(region->buffer.get()); +// reads.push_back(std::move(read_set)); +// writes.push_back(std::move(write_set)); +// } + +// auto intersect_fn = [](const std::set& lhs, +// const std::set& 
rhs) { +// for (auto ptr : lhs) +// if (rhs.count(ptr)) return true; +// return false; +// }; + +// std::vector sync_patterns; +// // producer_release consumer_acquire, +// // inject before the first consumer stmt for each producer +// for (int i = 0; i < n; i++) { +// for (int j = i + 1; j < n; j++) { +// if (is_producer[i] != is_producer[j] && +// (intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) { +// sync_patterns.push_back({i, j}); +// break; +// } +// } +// } + +// // consumer_release producer_acquire +// // valid when is_loop is true +// // inject before the earlest producer stmt for each consumer +// bool in_loop = !is_zero(parity_); +// if (in_loop) { +// for (int i = 0; i < n; i++) { +// for (int j = 0; j < i; j++) { +// if (is_producer[i] != is_producer[j] && +// (intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) { +// sync_patterns.push_back({i, j}); +// break; +// } +// } +// } +// } + +// return sync_patterns; +// } + +// static std::vector RemoveUnusedSyncPatterns( +// const std::vector& sync_patterns, const std::vector& is_producer) { +// /* +// Simplify multiple release-acquire pairs into one +// ------------------ +// Produce(A) +// Produce(B) +// Consume(A, B) +// ------------------ +// [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)] + +// Or +// ------------------ +// Produce(A, B) +// Consume(A) +// Consume(B) +// ------------------ +// [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)] +// */ +// int M = sync_patterns.size(); +// std::vector removed(M, false); +// for (int i = 0; i < M; i++) { +// for (int j = 0; j < M; j++) { +// if (is_producer[sync_patterns[i].acquire_idx] == +// is_producer[sync_patterns[j].acquire_idx] && +// sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx && +// sync_patterns[i].release_idx < sync_patterns[j].release_idx) +// removed[i] = true; +// } +// } + +// std::vector sync_pattern_cleaned; +// sync_pattern_cleaned.reserve(M); +// for (int i = 0; i < M; i++) +// if (!removed[i]) sync_pattern_cleaned.push_back(sync_patterns[i]); + +// return sync_pattern_cleaned; +// } + +// SyncPatternMap ExtractSyncPattern(Array seq_stmt) { +// size_t num_stmts = seq_stmt.size(); +// std::vector is_producer; +// is_producer.reserve(num_stmts); +// for (auto stmt : seq_stmt) { +// is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer); +// } + +// auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer); +// auto sync_patterns = RemoveUnusedSyncPatterns(sync_patterns_base, is_producer); + +// // for (auto pattern : sync_patterns) { +// // std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl; +// // } + +// SyncPatternMap map; +// map.patterns = sync_patterns; +// map.acquire.resize(num_stmts, -1); +// map.release.resize(num_stmts, -1); +// map.release_after.resize(num_stmts, false); +// for (size_t i = 0; i < sync_patterns.size(); i++) { +// map.acquire[sync_patterns[i].acquire_idx] = i; +// map.release[sync_patterns[i].release_idx] = i; +// map.release_after[sync_patterns[i].release_idx] = true; +// } + +// int cur_consumer_barrier = -1, cur_producer_barrier = -1; +// for (int i = num_stmts - 1; i >= 0; i--) { +// if (is_producer[i]) { +// if (map.release[i] == -1) { +// map.release[i] = cur_producer_barrier; +// } else { +// cur_producer_barrier = map.release[i]; +// } +// } else { +// if (map.release[i] == -1) { +// map.release[i] = cur_consumer_barrier; +// } else { +// cur_consumer_barrier = map.release[i]; +// } +// } +// } +// return map; +// } + +// const 
bool is_emitting_producer_; +// Map buffer_data_to_buffer_; +// std::unordered_set released_barrier_; +// const WarpSpecializedRoleMarker& marker_; + +// int num_barriers_ = 0; +// PrimExpr parity_ = 0; +// PrimExpr stage_ = 0; +// int num_stages_ = 1; +// Var thread_var_; +// friend class WarpSpecializedPipeline; +// }; + +// class WarpSpecializedPipeline : public StmtExprMutator { +// public: +// static PrimFunc Substitute(PrimFunc f) { +// f = MultiVersionBufferRewriter::Rewrite(f); +// auto T = WarpSpecializedPipeline(); +// T.buffer_lca_ = DetectBufferAccessLCA(f); +// for (auto [buffer, _] : T.buffer_lca_) T.buffer_data_to_buffer_.Set(buffer->data, buffer); +// f.CopyOnWrite()->body = T(f->body); +// return f; +// } + +// private: +// Stmt VisitStmt_(const AttrStmtNode* op) final { +// if (op->attr_key == tir::attr::thread_extent && +// Downcast(op->node)->thread_tag == "threadIdx.x") { +// thread_iv_ = Downcast(op->node); +// AttrStmt attr_stmt = Downcast(StmtExprMutator::VisitStmt_(op)); +// if (updated_thread_extent_.defined()) { +// thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()}; +// attr_stmt.CopyOnWrite()->node = thread_iv_; +// attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value(); +// } +// thread_iv_ = {}; +// return attr_stmt; +// } else { +// return StmtExprMutator::VisitStmt_(op); +// } +// } + +// Stmt VisitStmt_(const BlockRealizeNode* op) final { +// BlockRealize block_realize = Downcast(StmtExprMutator::VisitStmt_(op)); +// if (!thread_iv_.defined()) { +// return block_realize; +// } +// ICHECK(!updated_thread_extent_.defined()); + +// Block block = block_realize->block; +// WarpSpecializedRoleMarker marker(buffer_data_to_buffer_); +// marker(block); +// if (!marker.HasProducer()) { +// // Cannot detect any producer here, directly return. +// return block_realize; +// } + +// WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker); +// WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker); +// Stmt producer_code = producer(block->body); +// Stmt consumer_code = consumer(block->body); + +// PrimExpr consumer_thread_extent = thread_iv_->dom->extent; +// PrimExpr producer_thread_extent = thread_iv_->dom->extent; +// // Need one warp-group for bulk-copy only case +// if (!marker.HasSimtCopy()) producer_thread_extent = 128; + +// // TODO: estimate the correct reg usage. +// auto inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1})); +// auto dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0})); + +// producer_code = SeqStmt({dec_reg_stmt, producer_code}); +// consumer_code = SeqStmt({inc_reg_stmt, consumer_code}); + +// producer_code = ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var, +// thread_iv_->var - consumer_thread_extent); +// updated_thread_extent_ = consumer_thread_extent + producer_thread_extent; + +// ICHECK(producer.num_barriers_ == consumer.num_barriers_) +// << producer.num_barriers_ << " " << consumer.num_barriers_; +// int num_barriers = consumer.num_barriers_; +// Array barrier_num_threads; +// barrier_num_threads.reserve(num_barriers); +// for (int i = 0; i < num_barriers; i++) { +// PrimExpr arrive_thread_count = +// producer.released_barrier_.count(i) ? 
producer_thread_extent : consumer_thread_extent; +// barrier_num_threads.push_back(arrive_thread_count); +// } + +// Stmt init_barrier = +// Evaluate(Call(DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads)); +// Stmt body = +// IfThenElse(GE(thread_iv_->var, consumer_thread_extent), producer_code, consumer_code); +// // Add an attr here to handle the partial thread count in THreadSync pass. +// Array ws_partition = {Downcast(producer_thread_extent), +// Downcast(consumer_thread_extent)}; +// body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body); + +// block.CopyOnWrite()->body = SeqStmt({init_barrier, body}); +// block_realize.CopyOnWrite()->block = block; +// return block_realize; +// } + +// WarpSpecializedPipeline() = default; + +// Map buffer_data_to_buffer_; +// Map> buffer_lca_; +// Map buffer_remap_; +// IterVar thread_iv_; +// Optional updated_thread_extent_; +// }; + +// using namespace tir::transform; + +// tvm::transform::Pass WarpSpecializedPipeline() { +// auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { +// return WarpSpecializedPipeline::Substitute(f); +// }; +// return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecializedPipeline", {}); +// } + +// TVM_REGISTER_GLOBAL("tl.WarpSpecializedPipeline").set_body_typed(WarpSpecializedPipeline); + +// } // namespace tl +// } // namespace tvm diff --git a/src/tl/transform/warp_specialized_rewriter.cc b/src/tl/transform/warp_specialized_rewriter.cc new file mode 100644 index 000000000000..37b1d6a9b5b2 --- /dev/null +++ b/src/tl/transform/warp_specialized_rewriter.cc @@ -0,0 +1,738 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file warp_specialized_pipeline.cc + * \brief Warp specialized Pipeline for cuda GPU (sm90+) + */ + +#include +#include +#include +#include +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +enum class Role { kConsumer, kProducer, kBoth }; + +class WarpSpecializedRoleMarker : public StmtVisitor { + public: + WarpSpecializedRoleMarker(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(buffer_data_to_buffer) {} + + Role GetRole(const StmtNode* stmt) const { + auto it = map_.find(stmt); + ICHECK(it != map_.end()); + return it->second; + } + + Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); } + + void VisitStmt_(const EvaluateNode* op) final { + Role role = Role::kConsumer; + if (auto call = op->value.as()) { + if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { + role = Role::kProducer; + has_bulk_copy_ = true; + } + } + SetRole(op, role); + } + + void VisitStmt_(const BufferStoreNode* op) final { + bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; + if (!is_shared_store) { + SetRole(op, Role::kConsumer); + return; + } + + // Check reads from global + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ GetRef(op)); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto reads = access[0]; + Role role = Role::kProducer; + for (auto read : reads) { + if (read->buffer.scope() != "global") { + role = Role::kConsumer; + break; + } + } + if (role == Role::kProducer) has_simt_copy_ = true; + SetRole(op, role); + } + + void VisitStmt_(const SeqStmtNode* op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->seq[0]); + for (auto stmt : op->seq) { + if (role != GetRole(stmt)) { + role = Role::kBoth; + break; + } + } + SetRole(op, role); + } + + void VisitStmt_(const IfThenElseNode* op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->then_case); + if (op->else_case.defined()) { + auto role_else = GetRole(op->else_case.value()); + if (role != role_else) role = Role::kBoth; + } + SetRole(op, role); + } + + void VisitStmt_(const BlockRealizeNode* op) final { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->block)); + } + + template + void HandleBodyStmt(const NodeType* op) { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->body)); + } + + void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); } + void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); } + + bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } + + bool HasSimtCopy() { return has_simt_copy_; } + + private: + void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; } + Map buffer_data_to_buffer_; + std::unordered_map map_; + bool has_simt_copy_ = false; + bool has_bulk_copy_ = false; +}; + +static PrimExpr makeGetBarrier(PrimExpr barrier_id) { + return Call(DataType::Handle(), GetMBarrierOp(), {barrier_id}); +} + +static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) { + auto call = Call(DataType::Handle(), MBarrierExpectTX(), {makeGetBarrier(barrier_id), bytes}); + return Evaluate(call); +} + +static Stmt makeArriveBarrier(PrimExpr barrier_id) { + auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(), 
{makeGetBarrier(barrier_id)}); + return Evaluate(call); +} + +static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) { + auto call = + Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), {makeGetBarrier(barrier_id)}); + return Evaluate(call); +} + +static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { + auto call = Call(DataType::Handle(), MBarrierWaitParity(), {makeGetBarrier(barrier_id), parity}); + return Evaluate(call); +} + +class ProducerTraitsCollector : public StmtExprVisitor { + public: + ProducerTraitsCollector() { Clear(); } + + void Clear() { + bulk_copy_bytes = 0; + loop_extents = 1; + has_simt_copy = false; + } + + void Collect(Stmt stmt) { VisitStmt(stmt); } + + bool HasSimtCopy() { return has_simt_copy; } + + PrimExpr BulkCopyBytes() { return bulk_copy_bytes; } + + private: + void VisitExpr_(const CallNode* call) final { + if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { + Call access_ptr = Downcast(call->args[2]); + ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); + int type_bytes = access_ptr->args[0]->dtype.bytes(); + bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes; + } + StmtExprVisitor::VisitExpr_(call); + } + + void VisitStmt_(const ForNode* op) final { + PrimExpr old_loop_evtents = loop_extents; + loop_extents *= op->extent; + StmtExprVisitor::VisitStmt_(op); + loop_extents = old_loop_evtents; + } + + void VisitExpr_(const BufferLoadNode* op) final { + has_simt_copy = true; + StmtExprVisitor::VisitExpr_(op); + } + + bool has_simt_copy; + PrimExpr bulk_copy_bytes; + PrimExpr loop_extents; +}; + +// Rewrite the producer Stmt to use the correct barrier index +class MbarrierRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) { + MbarrierRewriter rewriter; + rewriter.producer_barrier_idx_ = barrier_id; + return rewriter(stmt); + } + + private: + PrimExpr VisitExpr_(const CallNode* op) final { + auto call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { + Call access_ptr = Downcast(call->args[2]); + ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); + call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_)); + } + return call; + } + PrimExpr producer_barrier_idx_; +}; + + +class ThreadIdxRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) { + auto rewriter = ThreadIdxRewriter(thread_var, replaced); + return rewriter(stmt); + } + + private: + ThreadIdxRewriter(Var thread_var, PrimExpr replaced) + : thread_var_(thread_var), replaced_(replaced) {} + + PrimExpr VisitExpr_(const VarNode* var) final { + if (var == thread_var_.get()) { + return replaced_; + } else { + return StmtExprMutator::VisitExpr_(var); + } + } + + Var thread_var_; + PrimExpr replaced_; +}; + +Block MakeGroupBlock(const Stmt& stmt, const Map& annotations) { + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt, + /*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/annotations); + return block; +} + +class GroupOpRewriter : public StmtExprMutator { + public: + GroupOpRewriter(Array>& group_info) : group_info_(group_info) {} + + private: + Stmt VisitStmt_(const ForNode* op) final { + Map annotations; + annotations.Set(String("stmt_group"), Integer(1)); + auto original_node = (op->body).as(); + if (!original_node) { + return GetRef(op); + } + Array new_body; + for (size_t i = 0; 
i < group_info_.size(); i++) { + if (group_info_[i].size() == 0) continue; + Array block_stmt; + for (size_t j = 0; j < group_info_[i].size(); j++) { + ICHECK(group_info_[i][j].as()); + int index = static_cast(group_info_[i][j].as()->value); + ICHECK(original_node->seq[index].as()); + auto block = original_node->seq[index].as(); + // TODO: handle nested seqstmt + block_stmt.push_back(block->body); + } + new_body.push_back( + MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); + } + For new_for = For(op->loop_var, op->min, op->extent, op->kind, new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)), op->thread_binding, op->annotations); + return new_for; + } + + Array> group_info_; +}; +class WSCodeEmitter : public StmtMutator { + public: + WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv, + Map buffer_data_to_buffer, const WarpSpecializedRoleMarker& marker) + : is_emitting_producer_(is_emitting_producer), + buffer_data_to_buffer_(buffer_data_to_buffer), + marker_(marker), + thread_var_(thread_iv->var) {} + + private: + template + Stmt FilterByRole(const NodeType* op) { + Role role = marker_.GetRole(op); + if (role == Role::kBoth) + return StmtMutator::VisitStmt_(op); + else if ((role == Role::kProducer) == is_emitting_producer_) + return GetRef(op); + else + return Evaluate(0); + } + + // TODO: only need to add block for ops in the loop + Stmt VisitStmt_(const SeqStmtNode* op) final { + bool has_producer = false; + for (auto stmt : op->seq) { + if (marker_.GetRole(stmt) == Role::kProducer) { + has_producer = true; + break; + } + } + bool need_producer_sync = has_producer && marker_.GetRole(op) == Role::kBoth; + if (!need_producer_sync) return FilterByRole(op); + + auto seq_transformed = op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); }); + + auto map = ExtractSyncPattern(op->seq); + Array new_body; + Map annotations; + annotations.Set(String("stmt_group"), Integer(1)); + + if (is_emitting_producer_) { // producer case + ProducerTraitsCollector collector; + for (int i = 0; i < static_cast(op->seq.size()); i++) { + Array block_stmt = {}; + if (marker_.GetRole(op->seq[i]) == Role::kConsumer) continue; + if (marker_.GetRole(op->seq[i]) == Role::kBoth) { + block_stmt.push_back(seq_transformed[i]); + new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); + continue; + } + if (map.acquire[i] != -1) { + PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i]; + PrimExpr parity = + map.is_loop_dependency(map.acquire[i]) ? 
bitwise_xor(parity_, 1) : parity_; + block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); + } + ICHECK(map.release[i] >= 0); + PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; + auto stmt = MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id); + collector.Collect(stmt); + if (!is_zero(collector.BulkCopyBytes())) { + auto expect_tx = IfThenElse(EQ(thread_var_, 0), + makeExpectTX(release_barrier_id, collector.BulkCopyBytes())); + block_stmt.push_back(expect_tx); + } + block_stmt.push_back(stmt); + if (collector.HasSimtCopy() > 0) { + block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id)); + } + if (map.release_after[i]) { + block_stmt.push_back(makeArriveBarrier(release_barrier_id)); + for (int j = 0; j < num_stages_; j++) { + released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); + } + } + collector.Clear(); + new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); + } + } else { // consumer case + for (int i = 0; i < static_cast(op->seq.size()); i++) { + Array block_stmt = {}; + if (marker_.GetRole(op->seq[i]) == Role::kProducer) continue; + if (map.acquire[i] != -1) { + PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i]; + PrimExpr parity = + map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_; + block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); + } + block_stmt.push_back(seq_transformed[i]); + if (map.release_after[i]) { + PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; + block_stmt.push_back(makeArriveBarrier(release_barrier_id)); + for (int j = 0; j < num_stages_; j++) { + released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); + } + } + new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); + } + } + + num_barriers_ += map.patterns.size() * num_stages_; + + ICHECK(new_body.size() > 0); + return new_body.size() == 1 ? 
new_body[0] : SeqStmt(std::move(new_body)); + } + + Stmt VisitStmt_(const ForNode* op) final { + int num_stages = 1; + auto num_stages_anno = op->annotations.Get("num_stages"); + if (num_stages_anno.defined()) { + ICHECK(num_stages_anno.as()); + num_stages = static_cast(num_stages_anno.as()->value); + ICHECK(num_stages_ == 1) << "Nested pipeline not supported."; + } + + PrimExpr parity_before = std::move(parity_); + PrimExpr stage_before = std::move(stage_); + int num_stages_before = num_stages_; + + num_stages_ = num_stages; + stage_ = FloorMod(op->loop_var - op->min, num_stages); + parity_ = + FloorMod(parity_before * op->extent + FloorDiv(op->loop_var - op->min, num_stages), 2); + + auto result = FilterByRole(op); + + parity_ = std::move(parity_before); + stage_ = std::move(stage_before); + num_stages_ = num_stages_before; + + // remove pipeline annotation + auto for_node = result.as(); + if (result.as()) { + auto for_node = Downcast(result); + for_node.CopyOnWrite()->annotations.erase("num_stages"); + if (is_emitting_producer_) { + for_node.CopyOnWrite()->annotations.erase("software_pipeline_order"); + for_node.CopyOnWrite()->annotations.erase("software_pipeline_stage"); + } + auto group_info_anno = op->annotations.Get("software_pipeline_group"); + if (is_emitting_producer_ || !group_info_anno.defined()) { + return for_node; + } + auto group_info = + Downcast>>(op->annotations.at("software_pipeline_group")); + GroupOpRewriter group_op_rewriter(group_info); + for_node.CopyOnWrite()->annotations.erase("software_pipeline_group"); + Stmt grouped_for_node = group_op_rewriter(for_node); + return grouped_for_node; + } + return result; + } + + Stmt VisitStmt_(const IfThenElseNode* op) final { return FilterByRole(op); } + Stmt VisitStmt_(const EvaluateNode* op) final { return FilterByRole(op); } + Stmt VisitStmt_(const AttrStmtNode* op) final { return FilterByRole(op); } + Stmt VisitStmt_(const BufferStoreNode* op) final { return FilterByRole(op); } + Stmt VisitStmt_(const LetStmtNode* op) final { return FilterByRole(op); } + Stmt VisitStmt_(const AssertStmtNode* op) final { return FilterByRole(op); } + Stmt VisitStmt_(const BlockNode* op) final { + ICHECK(0); + return Stmt(); + } + Stmt VisitStmt_(const BlockRealizeNode* op) final { + ICHECK(0); + return Stmt(); + } + + struct SyncPattern { + int release_idx, acquire_idx; + }; + + struct SyncPatternMap { + std::vector acquire; + std::vector release; + std::vector release_after; + std::vector patterns; + bool is_loop_dependency(int i) { + // return if the acquire is based on release in the previous iteration + return patterns[i].release_idx > patterns[i].acquire_idx; + } + }; + + std::vector CreateBaseSyncPairs(Array seq_stmt, + const std::vector& is_producer) { + const int n = seq_stmt.size(); + std::vector> reads, writes; + reads.reserve(n); + writes.reserve(n); + for (int i = 0; i < n; i++) { + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ seq_stmt[i]); + auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); + std::set read_set, write_set; + for (auto region : access[0]) read_set.insert(region->buffer.get()); + for (auto region : access[1]) write_set.insert(region->buffer.get()); + reads.push_back(std::move(read_set)); + writes.push_back(std::move(write_set)); + } + + auto intersect_fn = [](const std::set& lhs, + const std::set& rhs) { + for (auto ptr : lhs) + if (rhs.count(ptr)) return true; + return false; + }; + + std::vector sync_patterns; + // producer_release consumer_acquire, + // 
inject before the first consumer stmt for each producer + for (int i = 0; i < n; i++) { + for (int j = i + 1; j < n; j++) { + if (is_producer[i] != is_producer[j] && + (intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) { + sync_patterns.push_back({i, j}); + break; + } + } + } + + // consumer_release producer_acquire + // valid when is_loop is true + // inject before the earlest producer stmt for each consumer + bool in_loop = !is_zero(parity_); + if (in_loop) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < i; j++) { + if (is_producer[i] != is_producer[j] && + (intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) { + sync_patterns.push_back({i, j}); + break; + } + } + } + } + + return sync_patterns; + } + + static std::vector RemoveUnusedSyncPatterns( + const std::vector& sync_patterns, const std::vector& is_producer) { + /* + Simplify multiple release-acquire pairs into one + ------------------ + Produce(A) + Produce(B) + Consume(A, B) + ------------------ + [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)] + + Or + ------------------ + Produce(A, B) + Consume(A) + Consume(B) + ------------------ + [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)] + */ + int M = sync_patterns.size(); + std::vector removed(M, false); + for (int i = 0; i < M; i++) { + for (int j = 0; j < M; j++) { + if (is_producer[sync_patterns[i].acquire_idx] == + is_producer[sync_patterns[j].acquire_idx] && + sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx && + sync_patterns[i].release_idx < sync_patterns[j].release_idx) + removed[i] = true; + } + } + + std::vector sync_pattern_cleaned; + sync_pattern_cleaned.reserve(M); + for (int i = 0; i < M; i++) + if (!removed[i]) sync_pattern_cleaned.push_back(sync_patterns[i]); + + return sync_pattern_cleaned; + } + + SyncPatternMap ExtractSyncPattern(Array seq_stmt) { + size_t num_stmts = seq_stmt.size(); + std::vector is_producer; + is_producer.reserve(num_stmts); + for (auto stmt : seq_stmt) { + is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer); + } + + auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer); + auto sync_patterns = RemoveUnusedSyncPatterns(sync_patterns_base, is_producer); + + // for (auto pattern : sync_patterns) { + // std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl; + // } + + SyncPatternMap map; + map.patterns = sync_patterns; + map.acquire.resize(num_stmts, -1); + map.release.resize(num_stmts, -1); + map.release_after.resize(num_stmts, false); + for (size_t i = 0; i < sync_patterns.size(); i++) { + map.acquire[sync_patterns[i].acquire_idx] = i; + map.release[sync_patterns[i].release_idx] = i; + map.release_after[sync_patterns[i].release_idx] = true; + } + + int cur_consumer_barrier = -1, cur_producer_barrier = -1; + for (int i = num_stmts - 1; i >= 0; i--) { + if (is_producer[i]) { + if (map.release[i] == -1) { + map.release[i] = cur_producer_barrier; + } else { + cur_producer_barrier = map.release[i]; + } + } else { + if (map.release[i] == -1) { + map.release[i] = cur_consumer_barrier; + } else { + cur_consumer_barrier = map.release[i]; + } + } + } + return map; + } + + const bool is_emitting_producer_; + Map buffer_data_to_buffer_; + std::unordered_set released_barrier_; + const WarpSpecializedRoleMarker& marker_; + + int num_barriers_ = 0; + PrimExpr parity_ = 0; + PrimExpr stage_ = 0; + int num_stages_ = 1; + Var thread_var_; + friend class WarpSpecializedRewriter; +}; + +class WarpSpecializedRewriter : public StmtExprMutator { + public: + static 
PrimFunc Substitute(PrimFunc f) { + auto T = WarpSpecializedRewriter(); + T.buffer_lca_ = DetectBufferAccessLCA(f); + for (auto [buffer, _] : T.buffer_lca_) T.buffer_data_to_buffer_.Set(buffer->data, buffer); + f.CopyOnWrite()->body = T(f->body); + return f; + } + + private: + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tir::attr::thread_extent && + Downcast(op->node)->thread_tag == "threadIdx.x") { + thread_iv_ = Downcast(op->node); + AttrStmt attr_stmt = Downcast(StmtExprMutator::VisitStmt_(op)); + if (updated_thread_extent_.defined()) { + thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()}; + attr_stmt.CopyOnWrite()->node = thread_iv_; + attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value(); + } + thread_iv_ = {}; + return attr_stmt; + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + BlockRealize block_realize = Downcast(StmtExprMutator::VisitStmt_(op)); + if (!thread_iv_.defined()) { + return block_realize; + } + ICHECK(!updated_thread_extent_.defined()); + + Block block = block_realize->block; + WarpSpecializedRoleMarker marker(buffer_data_to_buffer_); + marker(block); + if (!marker.HasProducer()) { + // Cannot detect any producer here, directly return. + return block_realize; + } + + WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker); + WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker); + Stmt producer_code = producer(block->body); + Stmt consumer_code = consumer(block->body); + + PrimExpr consumer_thread_extent = thread_iv_->dom->extent; + PrimExpr producer_thread_extent = thread_iv_->dom->extent; + // Need one warp-group for bulk-copy only case + if (!marker.HasSimtCopy()) producer_thread_extent = 128; + + // TODO: estimate the correct reg usage. + auto inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1})); + auto dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0})); + + producer_code = SeqStmt({dec_reg_stmt, producer_code}); + consumer_code = SeqStmt({inc_reg_stmt, consumer_code}); + + producer_code = ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var, + thread_iv_->var - consumer_thread_extent); + updated_thread_extent_ = consumer_thread_extent + producer_thread_extent; + + ICHECK(producer.num_barriers_ == consumer.num_barriers_) + << producer.num_barriers_ << " " << consumer.num_barriers_; + int num_barriers = consumer.num_barriers_; + Array barrier_num_threads; + barrier_num_threads.reserve(num_barriers); + for (int i = 0; i < num_barriers; i++) { + PrimExpr arrive_thread_count = + producer.released_barrier_.count(i) ? producer_thread_extent : consumer_thread_extent; + barrier_num_threads.push_back(arrive_thread_count); + } + + Stmt init_barrier = + Evaluate(Call(DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads)); + Stmt body = + IfThenElse(GE(thread_iv_->var, consumer_thread_extent), producer_code, consumer_code); + // Add an attr here to handle the partial thread count in THreadSync pass. 
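+    // Note: threads [0, consumer_extent) take the consumer branch and threads
+    // [consumer_extent, consumer_extent + producer_extent) take the producer
+    // branch, with ThreadIdxRewriter shifting the producer's threadIdx.x down
+    // by consumer_extent. The 240/24 values passed to SetMaxNReg above are
+    // fixed placeholders rather than measured register budgets (see the TODO).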
+ Array ws_partition = {Downcast(producer_thread_extent), + Downcast(consumer_thread_extent)}; + body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body); + + block.CopyOnWrite()->body = SeqStmt({init_barrier, body}); + block_realize.CopyOnWrite()->block = block; + return block_realize; + } + + WarpSpecializedRewriter() = default; + + Map buffer_data_to_buffer_; + Map> buffer_lca_; + Map buffer_remap_; + IterVar thread_iv_; + Optional updated_thread_extent_; +}; + +using namespace tir::transform; + +tvm::transform::Pass WarpSpecialized() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return WarpSpecializedRewriter::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); +} + +TVM_REGISTER_GLOBAL("tl.WarpSpecialized").set_body_typed(WarpSpecialized); + +} // namespace tl +} // namespace tvm diff --git a/tl_scripts/mha_pipeline.py b/tl_scripts/mha_pipeline.py index d515427be6f6..16d38c14dda0 100644 --- a/tl_scripts/mha_pipeline.py +++ b/tl_scripts/mha_pipeline.py @@ -102,7 +102,7 @@ def main( V: T.Buffer(shape, dtype), Output: T.Buffer(shape, dtype), ): - with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128 * 2) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -121,21 +121,14 @@ def main( T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - MMA0(K, Q_shared, K_shared, acc_s, 0, by, bz) - Softmax(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) - loop_range = ( T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) ) - for k in T.Pipelined(loop_range, num_stages=1): - if k < loop_range - 1: - MMA0(K, Q_shared, K_shared, acc_s, k + 1, by, bz) - + # Body + for k in T.Pipelined(loop_range, num_stages=2, order=[0,2,1], stage=[0,0,1], group=[[0,1], [2,3,4,5,6,7,8,9,10,11], [12]]): + MMA0(K, Q_shared, K_shared, acc_s, k, by, bz) + Softmax(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) - - if k < loop_range - 1: - Softmax(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) - for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) @@ -150,14 +143,14 @@ def ref_program(Q, K, V, casual): if __name__ == "__main__": - BATCH, H, N_CTX, D_HEAD = 64, 16, 4096, 64 + BATCH, H, N_CTX, D_HEAD = 8, 8, 2048, 256 casual = False flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul if casual: total_flops *= 0.5 - BLOCK_M = 64 - BLOCK_N = 64 if D_HEAD <= 128 else 32 + BLOCK_M = 128 + BLOCK_N = 80 # if D_HEAD <= 128 else 32 program = flashattn(BATCH, H, N_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) ref_program = partial(ref_program, casual=casual) mod, params = tl.lower(program) @@ -167,6 +160,6 @@ def ref_program(Q, K, V, casual): latency = mod.do_bench(ref_program, warmup=500) print("{:.2f} ms".format(latency)) print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) - latency = mod.do_bench(mod) + latency = mod.do_bench(mod, n_warmup=10, n_repeat=10) print("{:.2f} ms".format(latency)) print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file From 
13ed600ac2a3f124368f599384e7b4f32bd50ee3 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Mon, 9 Sep 2024 09:30:24 +0000 Subject: [PATCH 15/23] [tl] Update tl_verify --- tl_verify/cuda_interface.cpp | 33 ++- tl_verify/fa_kernel.cu | 515 ++++++++++++++++++----------------- tl_verify/fa_kernel.hpp | 1 + tl_verify/main.py | 46 +++- tl_verify/setup.py | 87 +++++- 5 files changed, 405 insertions(+), 277 deletions(-) diff --git a/tl_verify/cuda_interface.cpp b/tl_verify/cuda_interface.cpp index 1d2b7fcce7ec..290187a527df 100644 --- a/tl_verify/cuda_interface.cpp +++ b/tl_verify/cuda_interface.cpp @@ -5,12 +5,12 @@ #include #include "fa_kernel.hpp" -void main_kernel_launcher(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output); +void main_kernel_launcher(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output, bool causal); void main_kernel_launcher_no_tma(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output); -at::Tensor kernel_function(at::Tensor Q, at::Tensor K, at::Tensor V) { +at::Tensor kernel_function(at::Tensor Q, at::Tensor K, at::Tensor V, bool causal) { at::Tensor output = torch::empty_like(Q); - main_kernel_launcher(Q, K, V, output); + main_kernel_launcher(Q, K, V, output, causal); return output; } @@ -25,8 +25,31 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("kernel_function_no_tma", &kernel_function_no_tma, "FA Kernel Launcher"); } -void main_kernel_launcher(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output) { - host_function(Flash_fwd_params{Q.data_ptr(), K.data_ptr(), V.data_ptr(), output.data_ptr(), Q.size(0), Q.size(1), Q.size(2), Q.size(3), 64, 64}); +void main_kernel_launcher(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output, bool causal) { + int batch = Q.size(0); + int seq_len = Q.size(1); + int heads = Q.size(2); + int dim = Q.size(3); + int block_M = 0; + int block_N = 0; + int threads = 0; + + if (dim == 64) { + block_M = 192; + block_N = 128; + threads = 16 * 32; + } else if (dim == 128) { + block_M = 128; + block_N = causal ? 
128 : 176; + threads = 12 * 32; + } else if (dim == 256) { + block_M = 128; + block_N = 80; + threads = 12 * 32; + } else { + throw std::invalid_argument("Invalid dimension"); + } + host_function(Flash_fwd_params{Q.data_ptr(), K.data_ptr(), V.data_ptr(), output.data_ptr(), batch, seq_len, heads, dim, block_M, block_N, threads}); } void main_kernel_launcher_no_tma(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output) { diff --git a/tl_verify/fa_kernel.cu b/tl_verify/fa_kernel.cu index 4709c03f92fa..746fbbf97094 100644 --- a/tl_verify/fa_kernel.cu +++ b/tl_verify/fa_kernel.cu @@ -6,7 +6,212 @@ #include #include "fa_kernel.hpp" -extern "C" __global__ void __launch_bounds__(128) main_kernel(__grid_constant__ const CUtensorMap K_desc, half_t* __restrict__ Output, __grid_constant__ const CUtensorMap Q_desc, __grid_constant__ const CUtensorMap V_desc) { +extern "C" __global__ void __launch_bounds__(512) main_kernel(__grid_constant__ const CUtensorMap K_desc, __grid_constant__ const CUtensorMap Output_desc, __grid_constant__ const CUtensorMap Q_desc, __grid_constant__ const CUtensorMap V_desc) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + float acc_o[32]; + float logsum[2]; + float scores_max[2]; + float acc_s[64]; + float scores_max_prev[2]; + float scores_scale[2]; + float scores_sum[2]; + half_t acc_s_cast[64]; + __shared__ uint64_t _mbarrier[11]; + if (((int)threadIdx.x) == 0) { + tl::prefetch_tma_descriptor(Q_desc); + tl::prefetch_tma_descriptor(K_desc); + tl::prefetch_tma_descriptor(V_desc); + tl::prefetch_tma_descriptor(Output_desc); + tl::mbarrier_init(_mbarrier[0], 128); + tl::mbarrier_init(_mbarrier[1], 128); + tl::mbarrier_init(_mbarrier[2], 128); + tl::mbarrier_init(_mbarrier[3], 128); + tl::mbarrier_init(_mbarrier[4], 384); + tl::mbarrier_init(_mbarrier[5], 384); + tl::mbarrier_init(_mbarrier[6], 384); + tl::mbarrier_init(_mbarrier[7], 384); + tl::mbarrier_init(_mbarrier[8], 128); + tl::mbarrier_init(_mbarrier[9], 384); + tl::mbarrier_init(_mbarrier[10], 384); + } + __syncthreads(); + if (384 <= ((int)threadIdx.x)) { + tl::warpgroup_reg_dealloc<32>(); + if (((int)threadIdx.x) == 384) { + tl::mbarrier_expect_tx(_mbarrier[8], 24576); + } + if (((int)threadIdx.x) == 384) { + tl::tma_load(Q_desc, _mbarrier[8], (&(((half_t*)buf_dyn_shmem)[0])), 0, ((int)blockIdx.y), (((int)blockIdx.x) * 192), ((int)blockIdx.z)); + } + tl::mbarrier_arrive(_mbarrier[8]); + for (int k = 0; k < 4; ++k) { + tl::mbarrier_wait(_mbarrier[((k & 1) + 4)], ((k >> 1) ^ 1)); + if (((int)threadIdx.x) == 384) { + tl::mbarrier_expect_tx(_mbarrier[(k & 1)], 16384); + } + if (((int)threadIdx.x) == 384) { + tl::tma_load(K_desc, _mbarrier[(k & 1)], (&(((half_t*)buf_dyn_shmem)[(((k & 1) * 8192) + 12288)])), 0, ((int)blockIdx.y), (k * 128), ((int)blockIdx.z)); + } + tl::mbarrier_arrive(_mbarrier[(k & 1)]); + tl::mbarrier_wait(_mbarrier[((k & 1) + 6)], ((k >> 1) ^ 1)); + if (((int)threadIdx.x) == 384) { + tl::mbarrier_expect_tx(_mbarrier[((k & 1) + 2)], 16384); + } + if (((int)threadIdx.x) == 384) { + tl::tma_load(V_desc, _mbarrier[((k & 1) + 2)], (&(((half_t*)buf_dyn_shmem)[(((k & 1) * 8192) + 28672)])), 0, ((int)blockIdx.y), (k * 128), ((int)blockIdx.z)); + } + tl::mbarrier_arrive(_mbarrier[((k & 1) + 2)]); + } + } else { + tl::warpgroup_reg_alloc<160>(); + #pragma unroll + for (int i = 0; i < 32; ++i) { + acc_o[i] = 0.000000e+00f; + } + #pragma unroll + for (int i_1 = 0; i_1 < 2; ++i_1) { + logsum[i_1] = 0.000000e+00f; + } + #pragma unroll + for (int i_2 = 0; i_2 < 2; ++i_2) { + scores_max[i_2] = 
-CUDART_INF_F; + } + tl::fence_proxy_async(); + tl::mbarrier_wait(_mbarrier[8], 0); + #pragma unroll + for (int i_3 = 0; i_3 < 64; ++i_3) { + acc_s[i_3] = 0.000000e+00f; + } + tl::fence_proxy_async(); + tl::mbarrier_wait(_mbarrier[0], 0); + tl::gemm_ss<192, 128, 64, 12, 1, 0, 1>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[12288])), (&(acc_s[0]))); + tl::mbarrier_arrive(_mbarrier[4]); + #pragma unroll + for (int i_4 = 0; i_4 < 2; ++i_4) { + scores_max_prev[i_4] = scores_max[i_4]; + } + #pragma unroll + for (int i_5 = 0; i_5 < 2; ++i_5) { + scores_max[i_5] = -CUDART_INF_F; + } + #pragma unroll + for (int i_6 = 0; i_6 < 2; ++i_6) { + #pragma unroll + for (int rv = 0; rv < 32; ++rv) { + scores_max[i_6] = max(scores_max[i_6], acc_s[((((rv & 15) * 4) + (i_6 * 2)) + (rv >> 4))]); + } + scores_max[i_6] = tl::AllReduce::run(scores_max[i_6]); + } + #pragma unroll + for (int i_7 = 0; i_7 < 2; ++i_7) { + scores_scale[i_7] = exp2f(((scores_max_prev[i_7] * 1.803369e-01f) - (scores_max[i_7] * 1.803369e-01f))); + } + #pragma unroll + for (int i_8 = 0; i_8 < 64; ++i_8) { + acc_s[i_8] = exp2f(((acc_s[i_8] * 1.803369e-01f) - (scores_max[((i_8 & 3) >> 1)] * 1.803369e-01f))); + } + #pragma unroll + for (int i_9 = 0; i_9 < 2; ++i_9) { + scores_sum[i_9] = 0.000000e+00f; + #pragma unroll + for (int rv_1 = 0; rv_1 < 32; ++rv_1) { + scores_sum[i_9] = (scores_sum[i_9] + acc_s[((((rv_1 & 15) * 4) + (i_9 * 2)) + (rv_1 >> 4))]); + } + scores_sum[i_9] = tl::AllReduce::run(scores_sum[i_9]); + } + #pragma unroll + for (int i_10 = 0; i_10 < 2; ++i_10) { + logsum[i_10] = ((logsum[i_10] * scores_scale[i_10]) + scores_sum[i_10]); + } + #pragma unroll + for (int i_11 = 0; i_11 < 32; ++i_11) { + acc_o[i_11] = (acc_o[i_11] * scores_scale[((i_11 & 3) >> 1)]); + } + #pragma unroll + for (int i_12 = 0; i_12 < 64; ++i_12) { + acc_s_cast[i_12] = ((half_t)acc_s[i_12]); + } + #pragma unroll 1 +for (int k_1 = 0; k_1 < 3; ++k_1) { + #pragma unroll + for (int i_13 = 0; i_13 < 64; ++i_13) { + acc_s[i_13] = 0.000000e+00f; + } + tl::fence_proxy_async(); + tl::mbarrier_wait(_mbarrier[((k_1 + 1) & 1)], ((k_1 + 1) >> 1)); + tl::gemm_ss<192, 128, 64, 12, 1, 0, 1,-1>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[((((k_1 + 1) & 1) * 8192) + 12288)])), (&(acc_s[0]))); + + tl::mbarrier_wait(_mbarrier[((k_1 & 1) + 2)], (k_1 >> 1)); + tl::gemm_rs<192, 64, 128, 12, 1, 0, 0,-1>((&(acc_s_cast[0])), (&(((half_t*)buf_dyn_shmem)[(((k_1 & 1) * 8192) + 28672)])), (&(acc_o[0]))); + + cute::warpgroup_wait<1>(); + tl::mbarrier_arrive(_mbarrier[(((k_1 + 1) & 1) + 4)]); + #pragma unroll + for (int i_14 = 0; i_14 < 2; ++i_14) { + scores_max_prev[i_14] = scores_max[i_14]; + } + #pragma unroll + for (int i_15 = 0; i_15 < 2; ++i_15) { + scores_max[i_15] = -CUDART_INF_F; + } + #pragma unroll + for (int i_16 = 0; i_16 < 2; ++i_16) { + #pragma unroll + for (int rv_2 = 0; rv_2 < 32; ++rv_2) { + scores_max[i_16] = max(scores_max[i_16], acc_s[((((rv_2 & 15) * 4) + (i_16 * 2)) + (rv_2 >> 4))]); + } + scores_max[i_16] = tl::AllReduce::run(scores_max[i_16]); + } + #pragma unroll + for (int i_17 = 0; i_17 < 2; ++i_17) { + scores_scale[i_17] = exp2f(((scores_max_prev[i_17] * 1.803369e-01f) - (scores_max[i_17] * 1.803369e-01f))); + } + #pragma unroll + for (int i_18 = 0; i_18 < 64; ++i_18) { + acc_s[i_18] = exp2f(((acc_s[i_18] * 1.803369e-01f) - (scores_max[((i_18 & 3) >> 1)] * 1.803369e-01f))); + } + #pragma unroll + for (int i_19 = 0; i_19 < 2; ++i_19) { + scores_sum[i_19] = 0.000000e+00f; + #pragma unroll + for (int rv_3 = 0; rv_3 < 32; 
++rv_3) { + scores_sum[i_19] = (scores_sum[i_19] + acc_s[((((rv_3 & 15) * 4) + (i_19 * 2)) + (rv_3 >> 4))]); + } + scores_sum[i_19] = tl::AllReduce::run(scores_sum[i_19]); + } + #pragma unroll + for (int i_20 = 0; i_20 < 2; ++i_20) { + logsum[i_20] = ((logsum[i_20] * scores_scale[i_20]) + scores_sum[i_20]); + } + cute::warpgroup_wait<0>(); + tl::mbarrier_arrive(_mbarrier[((k_1 & 1) + 6)]); + #pragma unroll + for (int i_21 = 0; i_21 < 32; ++i_21) { + acc_o[i_21] = (acc_o[i_21] * scores_scale[((i_21 & 3) >> 1)]); + } + #pragma unroll + for (int i_22 = 0; i_22 < 64; ++i_22) { + acc_s_cast[i_22] = ((half_t)acc_s[i_22]); + } + } + tl::mbarrier_wait(_mbarrier[3], 1); + tl::gemm_rs<192, 64, 128, 12, 1, 0, 0>((&(acc_s_cast[0])), (&(((half_t*)buf_dyn_shmem)[36864])), (&(acc_o[0]))); + tl::mbarrier_arrive(_mbarrier[7]); + #pragma unroll + for (int i_23 = 0; i_23 < 32; ++i_23) { + acc_o[i_23] = (acc_o[i_23] / logsum[((i_23 & 3) >> 1)]); + } + tl::syncthreads_partial(_mbarrier[9]); + #pragma unroll + for (int i_24 = 0; i_24 < 4; ++i_24) { + tl::ptx_stmatrix_x4((&(((half_t*)buf_dyn_shmem)[(((((((int)threadIdx.x) >> 5) * 1024) + ((((int)threadIdx.x) & 15) * 64)) + (i_24 * 16)) + (((((int)threadIdx.x) & 31) >> 4) * 8))])), __pack_half2(((half_t)acc_o[(i_24 * 8)]), ((half_t)acc_o[((i_24 * 8) + 1)])), __pack_half2(((half_t)acc_o[((i_24 * 8) + 2)]), ((half_t)acc_o[((i_24 * 8) + 3)])), __pack_half2(((half_t)acc_o[((i_24 * 8) + 4)]), ((half_t)acc_o[((i_24 * 8) + 5)])), __pack_half2(((half_t)acc_o[((i_24 * 8) + 6)]), ((half_t)acc_o[((i_24 * 8) + 7)]))); + } + tl::fence_proxy_async(); + tl::syncthreads_partial(_mbarrier[10]); + if (((int)threadIdx.x) == 0) { + tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[0])), 0, ((int)blockIdx.y), (((int)blockIdx.x) * 192), ((int)blockIdx.z)); + } + } } template @@ -54,17 +259,17 @@ struct TensorMapArgs { void host_function(Flash_fwd_params params) { int num_m_blocks = (params.seq_len + params.block_M - 1) / params.block_M; dim3 grid(num_m_blocks, params.head, params.batch); - dim3 block(128); - size_t sharedMemSize = (params.block_M + 2 * params.block_N) * params.dim * sizeof(half_t); // 24576; - - // int size = params.batch * params.head * params.seq_len * params.dim * sizeof(half_t); + dim3 block(params.threads); + size_t sharedMemSize = (params.block_M + 4 * params.block_N) * params.dim * sizeof(half_t); CUtensorMap Q_desc = {0}; CUtensorMap K_desc = {0}; CUtensorMap V_desc = {0}; + CUtensorMap O_desc = {0}; TensorMapArgs Q_arg; TensorMapArgs K_arg; TensorMapArgs V_arg; + TensorMapArgs O_arg; Q_arg.map = &Q_desc; Q_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; @@ -75,12 +280,12 @@ void host_function(Flash_fwd_params params) { Q_arg.globalDim[2] = static_cast(params.seq_len); Q_arg.globalDim[3] = static_cast(params.batch); Q_arg.globalStride[0] = static_cast(2); - Q_arg.globalStride[1] = static_cast(128); - Q_arg.globalStride[2] = static_cast(128); - Q_arg.globalStride[3] = static_cast(32768); + Q_arg.globalStride[1] = static_cast(2 * params.dim); + Q_arg.globalStride[2] = static_cast(2 * params.dim * params.head); + Q_arg.globalStride[3] = static_cast(2 * params.dim * params.head * params.seq_len); Q_arg.boxDim[0] = static_cast(64); Q_arg.boxDim[1] = static_cast(1); - Q_arg.boxDim[2] = static_cast(64); + Q_arg.boxDim[2] = static_cast(params.block_M); Q_arg.boxDim[3] = static_cast(1); Q_arg.elementStrides[0] = static_cast(1); Q_arg.elementStrides[1] = static_cast(1); @@ -95,17 +300,17 @@ void host_function(Flash_fwd_params params) { K_arg.type = 
CU_TENSOR_MAP_DATA_TYPE_FLOAT16; K_arg.tensorRank = 4; K_arg.globalAddress = params.k_ptr; - K_arg.globalDim[0] = static_cast(64); - K_arg.globalDim[1] = static_cast(1); - K_arg.globalDim[2] = static_cast(256); - K_arg.globalDim[3] = static_cast(1); + K_arg.globalDim[0] = static_cast(params.dim); + K_arg.globalDim[1] = static_cast(params.head); + K_arg.globalDim[2] = static_cast(params.seq_len); + K_arg.globalDim[3] = static_cast(params.batch); K_arg.globalStride[0] = static_cast(2); - K_arg.globalStride[1] = static_cast(128); - K_arg.globalStride[2] = static_cast(128); - K_arg.globalStride[3] = static_cast(32768); + K_arg.globalStride[1] = static_cast(2 * params.dim); + K_arg.globalStride[2] = static_cast(2 * params.dim * params.head); + K_arg.globalStride[3] = static_cast(2 * params.dim * params.head * params.seq_len); K_arg.boxDim[0] = static_cast(64); K_arg.boxDim[1] = static_cast(1); - K_arg.boxDim[2] = static_cast(64); + K_arg.boxDim[2] = static_cast(params.block_N); K_arg.boxDim[3] = static_cast(1); K_arg.elementStrides[0] = static_cast(1); K_arg.elementStrides[1] = static_cast(1); @@ -120,17 +325,17 @@ void host_function(Flash_fwd_params params) { V_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; V_arg.tensorRank = 4; V_arg.globalAddress = params.v_ptr; - V_arg.globalDim[0] = static_cast(64); - V_arg.globalDim[1] = static_cast(1); - V_arg.globalDim[2] = static_cast(256); - V_arg.globalDim[3] = static_cast(1); + V_arg.globalDim[0] = static_cast(params.dim); + V_arg.globalDim[1] = static_cast(params.head); + V_arg.globalDim[2] = static_cast(params.seq_len); + V_arg.globalDim[3] = static_cast(params.batch); V_arg.globalStride[0] = static_cast(2); - V_arg.globalStride[1] = static_cast(128); - V_arg.globalStride[2] = static_cast(128); - V_arg.globalStride[3] = static_cast(32768); + V_arg.globalStride[1] = static_cast(2 * params.dim); + V_arg.globalStride[2] = static_cast(2 * params.dim * params.head); + V_arg.globalStride[3] = static_cast(2 * params.dim * params.head * params.seq_len); V_arg.boxDim[0] = static_cast(64); V_arg.boxDim[1] = static_cast(1); - V_arg.boxDim[2] = static_cast(64); + V_arg.boxDim[2] = static_cast(params.block_N); V_arg.boxDim[3] = static_cast(1); V_arg.elementStrides[0] = static_cast(1); V_arg.elementStrides[1] = static_cast(1); @@ -141,6 +346,31 @@ void host_function(Flash_fwd_params params) { V_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; V_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + O_arg.map = &O_desc; + O_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + O_arg.tensorRank = 4; + O_arg.globalAddress = params.output_ptr; + O_arg.globalDim[0] = static_cast(params.dim); + O_arg.globalDim[1] = static_cast(params.head); + O_arg.globalDim[2] = static_cast(params.seq_len); + O_arg.globalDim[3] = static_cast(params.batch); + O_arg.globalStride[0] = static_cast(2); + O_arg.globalStride[1] = static_cast(2 * params.dim); + O_arg.globalStride[2] = static_cast(2 * params.dim * params.head); + O_arg.globalStride[3] = static_cast(2 * params.dim * params.head * params.seq_len); + O_arg.boxDim[0] = static_cast(64); + O_arg.boxDim[1] = static_cast(1); + O_arg.boxDim[2] = static_cast(params.block_M); + O_arg.boxDim[3] = static_cast(1); + O_arg.elementStrides[0] = static_cast(1); + O_arg.elementStrides[1] = static_cast(1); + O_arg.elementStrides[2] = static_cast(1); + O_arg.elementStrides[3] = static_cast(1); + O_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + O_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + O_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + 
O_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + CUresult result; result = cuTensorMapEncodeTiled( Q_arg.map, Q_arg.type, Q_arg.tensorRank, Q_arg.globalAddress, Q_arg.globalDim, Q_arg.globalStride + 1, Q_arg.boxDim, @@ -166,13 +396,24 @@ void host_function(Flash_fwd_params params) { << V_arg.ToDebugString(); } + result = cuTensorMapEncodeTiled( + O_arg.map, O_arg.type, O_arg.tensorRank, O_arg.globalAddress, O_arg.globalDim, O_arg.globalStride + 1, O_arg.boxDim, + O_arg.elementStrides, O_arg.interleave, O_arg.swizzle, O_arg.l2Promotion, O_arg.oobFill); + if (result != CUDA_SUCCESS) { + std::cout << "Failed to initialize the TMA descriptor " << result << std::endl + << O_arg.ToDebugString(); + } + + const int MAXBYTES = 1024 * 226; + cudaFuncSetAttribute(main_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, MAXBYTES); + cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; return; } - main_kernel<<>>(K_desc, (half_t*)params.output_ptr, Q_desc, V_desc); + main_kernel<<>>(K_desc, O_desc, Q_desc, V_desc); err = cudaGetLastError(); if (err != cudaSuccess) { @@ -185,224 +426,4 @@ void host_function(Flash_fwd_params params) { std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; return; } -} -// template -// static std::string ArrayToStr(const T* ptr, size_t n) { -// std::stringstream ss; -// ss << "["; -// for (size_t i = 0; i < n; i++) { -// if (i > 0) ss << ", "; -// ss << ptr[i]; -// } -// ss << "]"; -// return ss.str(); -// } - -// struct TensorMapArgs { -// CUtensorMap* map; -// CUtensorMapDataType type; -// cuuint32_t tensorRank; -// void* globalAddress; -// cuuint64_t globalDim[5], globalStride[5]; -// cuuint32_t boxDim[5], elementStrides[5]; -// CUtensorMapInterleave interleave; -// CUtensorMapSwizzle swizzle; -// CUtensorMapL2promotion l2Promotion; -// CUtensorMapFloatOOBfill oobFill; - -// std::string ToDebugString() { -// std::stringstream ss; -// ss << "TMA Desc Addr: " << map << std::endl -// << "format " << type << std::endl -// << "dim " << tensorRank << std::endl -// << "gmem_address " << globalAddress << std::endl -// << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl -// << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl -// << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl -// << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl -// << "interleave " << interleave << std::endl -// << "swizzle " << swizzle << std::endl -// << "l2Promotion " << l2Promotion << std::endl -// << "oobFill " << oobFill << std::endl; -// return ss.str(); -// } -// }; - -// __global__ void fillWithOnes(void *ptr, size_t size) { -// half_t *data = (half_t *)ptr; -// size_t index = threadIdx.x + blockIdx.x * blockDim.x; -// if (index < size) { -// data[index] = 1; -// } -// } - -// int main() { -// dim3 grid(4); -// dim3 block(129); -// size_t sharedMemSize = 24576; - -// int batch = 1; -// int head = 1; -// int seq_len = 256; -// int dim = 64; -// int size = batch * head * seq_len * dim * sizeof(half_t); - -// void *Q, *K, *V, *d_output; -// void *h_output; -// h_output = (void*)malloc(size); -// cudaMalloc((void**)&Q, size); -// cudaMalloc((void**)&K, size); -// cudaMalloc((void**)&V, size); -// cudaMalloc((void**)&d_output, size); - - -// int threadsPerBlock = 256; -// int blocksPerGrid = (batch * head * seq_len * dim + threadsPerBlock - 1) / threadsPerBlock; -// fillWithOnes<<>>(Q, 
batch * head * seq_len * dim + threadsPerBlock); -// fillWithOnes<<>>(K, batch * head * seq_len * dim + threadsPerBlock); -// fillWithOnes<<>>(V, batch * head * seq_len * dim + threadsPerBlock); - -// cudaError_t err = cudaDeviceSynchronize(); -// if (err != cudaSuccess) { -// std::cerr << "fillWithOnes failed: " << cudaGetErrorString(err) << std::endl; -// return 1; -// } - -// CUtensorMap Q_desc = {0}; -// CUtensorMap K_desc = {0}; -// CUtensorMap V_desc = {0}; -// TensorMapArgs Q_arg; -// TensorMapArgs K_arg; -// TensorMapArgs V_arg; - -// Q_arg.map = &Q_desc; -// Q_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; -// Q_arg.tensorRank = 4; -// Q_arg.globalAddress = Q; -// Q_arg.globalDim[0] = static_cast(64); -// Q_arg.globalDim[1] = static_cast(1); -// Q_arg.globalDim[2] = static_cast(256); -// Q_arg.globalDim[3] = static_cast(1); -// Q_arg.globalStride[0] = static_cast(2); -// Q_arg.globalStride[1] = static_cast(128); -// Q_arg.globalStride[2] = static_cast(128); -// Q_arg.globalStride[3] = static_cast(32768); -// Q_arg.boxDim[0] = static_cast(64); -// Q_arg.boxDim[1] = static_cast(1); -// Q_arg.boxDim[2] = static_cast(64); -// Q_arg.boxDim[3] = static_cast(1); -// Q_arg.elementStrides[0] = static_cast(1); -// Q_arg.elementStrides[1] = static_cast(1); -// Q_arg.elementStrides[2] = static_cast(1); -// Q_arg.elementStrides[3] = static_cast(1); -// Q_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; -// Q_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; -// Q_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; -// Q_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - -// K_arg.map = &K_desc; -// K_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; -// K_arg.tensorRank = 4; -// K_arg.globalAddress = K; -// K_arg.globalDim[0] = static_cast(64); -// K_arg.globalDim[1] = static_cast(1); -// K_arg.globalDim[2] = static_cast(256); -// K_arg.globalDim[3] = static_cast(1); -// K_arg.globalStride[0] = static_cast(2); -// K_arg.globalStride[1] = static_cast(128); -// K_arg.globalStride[2] = static_cast(128); -// K_arg.globalStride[3] = static_cast(32768); -// K_arg.boxDim[0] = static_cast(64); -// K_arg.boxDim[1] = static_cast(1); -// K_arg.boxDim[2] = static_cast(64); -// K_arg.boxDim[3] = static_cast(1); -// K_arg.elementStrides[0] = static_cast(1); -// K_arg.elementStrides[1] = static_cast(1); -// K_arg.elementStrides[2] = static_cast(1); -// K_arg.elementStrides[3] = static_cast(1); -// K_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; -// K_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; -// K_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; -// K_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - -// V_arg.map = &V_desc; -// V_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; -// V_arg.tensorRank = 4; -// V_arg.globalAddress = V; -// V_arg.globalDim[0] = static_cast(64); -// V_arg.globalDim[1] = static_cast(1); -// V_arg.globalDim[2] = static_cast(256); -// V_arg.globalDim[3] = static_cast(1); -// V_arg.globalStride[0] = static_cast(2); -// V_arg.globalStride[1] = static_cast(128); -// V_arg.globalStride[2] = static_cast(128); -// V_arg.globalStride[3] = static_cast(32768); -// V_arg.boxDim[0] = static_cast(64); -// V_arg.boxDim[1] = static_cast(1); -// V_arg.boxDim[2] = static_cast(64); -// V_arg.boxDim[3] = static_cast(1); -// V_arg.elementStrides[0] = static_cast(1); -// V_arg.elementStrides[1] = static_cast(1); -// V_arg.elementStrides[2] = static_cast(1); -// V_arg.elementStrides[3] = static_cast(1); -// V_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; -// V_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; -// 
V_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; -// V_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - -// CUresult result; -// result = cuTensorMapEncodeTiled( -// Q_arg.map, Q_arg.type, Q_arg.tensorRank, Q_arg.globalAddress, Q_arg.globalDim, Q_arg.globalStride + 1, Q_arg.boxDim, -// Q_arg.elementStrides, Q_arg.interleave, Q_arg.swizzle, Q_arg.l2Promotion, Q_arg.oobFill); -// if (result != CUDA_SUCCESS) { -// std::cout << "Failed to initialize the TMA descriptor " << result << std::endl -// << Q_arg.ToDebugString(); -// } - -// result = cuTensorMapEncodeTiled( -// K_arg.map, K_arg.type, K_arg.tensorRank, K_arg.globalAddress, K_arg.globalDim, K_arg.globalStride + 1, K_arg.boxDim, -// K_arg.elementStrides, K_arg.interleave, K_arg.swizzle, K_arg.l2Promotion, K_arg.oobFill); -// if (result != CUDA_SUCCESS) { -// std::cout << "Failed to initialize the TMA descriptor " << result << std::endl -// << K_arg.ToDebugString(); -// } - -// result = cuTensorMapEncodeTiled( -// V_arg.map, V_arg.type, V_arg.tensorRank, V_arg.globalAddress, V_arg.globalDim, V_arg.globalStride + 1, V_arg.boxDim, -// V_arg.elementStrides, V_arg.interleave, V_arg.swizzle, V_arg.l2Promotion, V_arg.oobFill); -// if (result != CUDA_SUCCESS) { -// std::cout << "Failed to initialize the TMA descriptor " << result << std::endl -// << V_arg.ToDebugString(); -// } - -// if (err != cudaSuccess) { -// std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; -// return 1; -// } - -// main_kernel<<>>(K_desc, (half_t*)d_output, Q_desc, V_desc); - -// err = cudaGetLastError(); -// if (err != cudaSuccess) { -// std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(err) << std::endl; -// return 1; -// } - -// err = cudaDeviceSynchronize(); -// if (err != cudaSuccess) { -// std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; -// return 1; -// } - -// cudaMemcpy((void*)h_output, (void*)d_output, size, cudaMemcpyDeviceToHost); - -// std::cout << "CUDA kernel executed successfully." 
<< std::endl; -// for (int i = 0; i < seq_len; i++) { -// for (int j = 0; j < dim; j++) { -// std::cout << ((half_t*)h_output)[i * dim + j] << " "; -// } -// std::cout << std::endl; -// } -// std::cout << std::endl; -// return 0; -// } \ No newline at end of file +} \ No newline at end of file diff --git a/tl_verify/fa_kernel.hpp b/tl_verify/fa_kernel.hpp index 0cd6f228cc3c..9b3103af1d52 100644 --- a/tl_verify/fa_kernel.hpp +++ b/tl_verify/fa_kernel.hpp @@ -18,6 +18,7 @@ struct Flash_fwd_params index_t dim; index_t block_M; index_t block_N; + index_t threads; }; void host_function(Flash_fwd_params params); diff --git a/tl_verify/main.py b/tl_verify/main.py index 0d00d5307af2..9fb8b9ef9e41 100644 --- a/tl_verify/main.py +++ b/tl_verify/main.py @@ -1,6 +1,6 @@ import torch import fa_test -# from flash_attn.flash_attn_interface import flash_attn_func +from flash_attn.flash_attn_interface import flash_attn_func import random import numpy as np @@ -51,19 +51,41 @@ def ref_program(Q, K, V, casual): return acc_o.to(torch.float16) set_seed(42) -batch, seq_len, heads, dim = 1, 256, 1, 64 +causal = False +batch, seq_len, heads, dim = 64, 512, 16, 64 shape = [batch, seq_len, heads, dim] # q = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) # k = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) -q = torch.ones(*shape, device='cuda', dtype=torch.float16) -k = torch.ones(*shape, device='cuda', dtype=torch.float16) +q = torch.ones(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) +k = torch.ones(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) v = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) -# output = fa_test.kernel_function(q, k, v) -output = fa_test.kernel_function_no_tma(q, k, v) -# ref_output = flash_attn_func(q, k, v, causal=False) -ref_output = ref_program(q, k, v, False) -print(output) -print(ref_output) -# print(ref_output) -# assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2) \ No newline at end of file +output = fa_test.kernel_function(q, k, v, causal) +ref_output = flash_attn_func(q, k, v, causal=False) +# ref_output = ref_program(q, k, v, causal) +assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2) +print("Check: PASSED") + +warmups = 10 +runs = 10 +for _ in range(warmups): + out = fa_test.kernel_function(q, k, v, causal) + +start_event = torch.cuda.Event(enable_timing=True) +end_event = torch.cuda.Event(enable_timing=True) + +start_event.record() + +for _ in range(runs): + out = fa_test.kernel_function(q, k, v, causal) + +end_event.record() +torch.cuda.synchronize() + +latency = start_event.elapsed_time(end_event) + +flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim +total_flops = 2 * flops_per_matmul +print(f"total_flops: {total_flops}") +print(f"TFLOPS: {total_flops / latency * runs * 1e-9}") +print(f"Latency: {latency / runs:.2f} ms") \ No newline at end of file diff --git a/tl_verify/setup.py b/tl_verify/setup.py index d38ed6637c21..f2fd38dcbc64 100644 --- a/tl_verify/setup.py +++ b/tl_verify/setup.py @@ -1,15 +1,33 @@ from setuptools import setup import torch.utils.cpp_extension -from torch.utils.cpp_extension import CUDAExtension, BuildExtension +import subprocess +from packaging.version import parse, Version +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME torch.utils.cpp_extension.CUDAExtension.debug = True -extra_compile_args = { - 'cxx': ['-O3', '-std=c++17'], - 'nvcc': [ - 
'-arch=sm_90a', - '--use_fast_math', - '-std=c++17', + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", "4"] + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + +cc_flag = [] +_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if bare_metal_version < Version("12.3"): + raise RuntimeError("FA Hopper is only supported on CUDA 12.3 and above") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_90a,code=sm_90a") + +nvcc_flags = [ "-O3", + # "-O0", + "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", @@ -18,12 +36,53 @@ "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - '-I/usr/local/cuda/include', - '-I/home/msra/cy/tvm.tl/src/tl', - '-I/home/msra/cy/tvm.tl/cutlass/include', - '-lcuda', - # '-keep' # Uncomment this line to keep the generated .ptx file - ], + "--use_fast_math", + # "--ptxas-options=-v", # printing out number of registers + "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers + "-lineinfo", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging + "-DNDEBUG", # Important, otherwise performance is severely impacted + "-DQBLKSIZE=128", + "-DKBLKSIZE=128", + "-DCTA256", + "-DDQINRMEM", + # "-keep" + ] + +# extra_compile_args = { +# 'cxx': ['-O3', '-std=c++17'], +# 'nvcc': [ +# # '-arch=sm_90a', +# '-gencode arch=compute_90a,code=compute_90a', +# '--use_fast_math', +# '-std=c++17', +# "-O3", +# "-U__CUDA_NO_HALF_OPERATORS__", +# "-U__CUDA_NO_HALF_CONVERSIONS__", +# "-U__CUDA_NO_BFLOAT16_OPERATORS__", +# "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", +# "-U__CUDA_NO_BFLOAT162_OPERATORS__", +# "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", +# "--expt-relaxed-constexpr", +# "--expt-extended-lambda", +# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers +# '-I/usr/local/cuda/include', +# '-I/home/msra/cy/tvm.tl/src/tl', +# '-I/home/msra/cy/tvm.tl/cutlass/include', +# '-lcuda', +# '-lineinfo', +# "-lnvToolsExt", +# "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging +# "-DNDEBUG", # Important, otherwise performance is severely impacted +# # '-keep' # Uncomment this line to keep the generated .ptx file +# ], +# } + +extra_compile_args = { + "cxx": ["-O3", "-std=c++17"], + "nvcc": append_nvcc_threads( + nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ), } include_dirs = [ @@ -47,3 +106,5 @@ 'build_ext': BuildExtension } ) + +# sudo -E env PATH=$PATH PYTHONPATH=$PYTHONPATH TMPDIR=~/cy/ncu_tmp ncu --set full -k regex:"main_kernel" --launch-count 1 --launch-skip 10 --target-processes application-only --cache-control none --clock-control none --apply-rules yes --import-source yes --check-exit-code yes -f -o reports/tl_8_2048_8_256_false /home/msra/miniconda3/envs/tl/bin/python main.py \ No newline at end of file From 1c6642ca22107e2333989c1b572ca9983697c059 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Wed, 11 Sep 2024 06:21:01 +0000 Subject: [PATCH 16/23] update mha_pipeline.py --- tl_scripts/mha_pipeline.py | 44 +++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/tl_scripts/mha_pipeline.py 
b/tl_scripts/mha_pipeline.py index 16d38c14dda0..be4de6e8bd6b 100644 --- a/tl_scripts/mha_pipeline.py +++ b/tl_scripts/mha_pipeline.py @@ -48,11 +48,18 @@ def MMA0( K_shared: T.Buffer([block_N, dim], dtype), acc_s: T.Buffer([block_M, block_N], accum_dtype), k: T.int32, + bx: T.int32, by: T.int32, bz: T.int32, ): T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) - T.clear(acc_s) + if is_casual: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) + ) + else: + T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @T.macro @@ -79,15 +86,21 @@ def Softmax( scores_sum: T.Buffer([block_M], accum_dtype), logsum: T.Buffer([block_M], accum_dtype), ): - for i, j in T.Parallel(block_M, dim): - acc_s[i, j] *= scale T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
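+            # (Here `scale` is assumed to already fold the log2(e) factor into
+            # the usual 1/sqrt(dim) softmax scaling, so exp2(x * scale) equals
+            # exp(x / sqrt(dim)).)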
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) for i in T.Parallel(block_M): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] @@ -106,6 +119,7 @@ def main( Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) acc_o = T.alloc_fragment([block_M, dim], accum_dtype) @@ -121,17 +135,21 @@ def main( T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) + # loop_range = ( + # T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) + # ) loop_range = ( - T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_casual else T.ceildiv(seq_len, block_N) ) - # Body - for k in T.Pipelined(loop_range, num_stages=2, order=[0,2,1], stage=[0,0,1], group=[[0,1], [2,3,4,5,6,7,8,9,10,11], [12]]): - MMA0(K, Q_shared, K_shared, acc_s, k, by, bz) + + for k in T.Pipelined(loop_range, num_stages=1, order=[0,2,1], stage=[0,0,1], group=[[0,1], [2,3,4,5,6,7,8,9,10], [11]]): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) Softmax(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main @@ -143,14 +161,14 @@ def ref_program(Q, K, V, casual): if __name__ == "__main__": - BATCH, H, N_CTX, D_HEAD = 8, 8, 2048, 256 - casual = False + BATCH, H, N_CTX, D_HEAD = 64, 12, 1024, 128 + casual = True flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul if casual: total_flops *= 0.5 BLOCK_M = 128 - BLOCK_N = 80 # if D_HEAD <= 128 else 32 + BLOCK_N = 128 # if D_HEAD <= 128 else 32 program = flashattn(BATCH, H, N_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) ref_program = partial(ref_program, casual=casual) mod, params = tl.lower(program) From 60abe81a5ea6905c6ea03b67d02c3d2e74d1435e Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Sat, 21 Sep 2024 14:48:10 +0000 Subject: [PATCH 17/23] [tl] Update. 
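
Notes on this update (summarizing the hunks below): the autotuner now runs
each candidate config in a separate process with a 40-second timeout and logs
results to out.log, T.Pipelined gains a `sync` parameter, a `reduce_abssum`
reduction is added, and Profiler.do_bench gains a `profiler` switch selecting
between the torch-based timer and TVM's time_evaluator. A minimal usage sketch
of the new switch (the TL `program` and `out_idx` below are assumed and not
part of this patch):

    from tvm import tl

    mod, params = tl.lower(program)
    prof = tl.Profiler(mod, params, out_idx, tl.TensorSupplyType.Normal)
    t_torch = prof.do_bench(prof.func, n_warmup=10, n_repeat=10, profiler="torch")
    t_tvm = prof.do_bench(prof.func, n_warmup=10, n_repeat=10, profiler="tvm")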
--- python/tvm/tl/autotuner.py | 49 +++++-- python/tvm/tl/language.py | 16 ++- python/tvm/tl/utils.py | 28 +++- src/tl/ir.cc | 11 +- src/tl/op/builtin.cc | 4 + src/tl/op/builtin.h | 8 ++ src/tl/op/reduce.cc | 8 ++ src/tl/op/reduce.h | 1 + src/tl/tl_templates/gemm_sm90.h | 80 +++++++---- tl_scripts/retnet_example.py | 105 +++++++------- tl_scripts/torch_ref.py | 246 ++++++++++++++++++++++++-------- 11 files changed, 395 insertions(+), 161 deletions(-) diff --git a/python/tvm/tl/autotuner.py b/python/tvm/tl/autotuner.py index 33916e78d493..e9dd4c90755c 100644 --- a/python/tvm/tl/autotuner.py +++ b/python/tvm/tl/autotuner.py @@ -19,8 +19,14 @@ from tvm import tl import inspect from functools import wraps -from typing import Any, Callable, List +from typing import Any, Callable, List, Any +import inspect +import multiprocessing +from tqdm import tqdm +import logging +logging.basicConfig(filename='out.log', filemode='w', level=logging.INFO, + format='%(asctime)s %(levelname)s:%(message)s') class Autotuner: def __init__( self, @@ -45,7 +51,17 @@ def run(self, *args: Any, **kwds: Any) -> Any: # print(f"{name} = {value}") best_latency = 1e8 best_config = None - for config in self.configs: + + def target_fn(pipe, *new_args, **kwds): + try: + latency, ref_latency = self.fn(*new_args, **kwds) + pipe.send((latency, ref_latency)) + except Exception as e: + logging.error(f"Fail on config {new_args} with error: {e}") + pipe.send((1e8, None)) + + progress_bar = tqdm(self.configs, desc="Running configurations") + for config in progress_bar: new_args = [] for name, value in bound_args.arguments.items(): if name not in self.keys: @@ -53,12 +69,24 @@ def run(self, *args: Any, **kwds: Any) -> Any: else: new_args.append(config[name]) new_args = tuple(new_args) - # print("auto-tunner new_args:", new_args) - try: - latency, ref_latency = self.fn(*new_args, **kwds) - except Exception as e: - print("Fail on config ", config, " with error: ", e) + + parent_pipe, child_pipe = multiprocessing.Pipe() + + p = multiprocessing.Process(target=target_fn, args=(child_pipe, *new_args), kwargs=kwds) + p.start() + + p.join(40) + if p.is_alive(): + logging.error(f"Killing config {config} due to timeout.") + p.terminate() + p.join() latency = 1e8 + else: + latency, ref_latency = parent_pipe.recv() + logging.info(f"Config {config} latency: {latency}") + + progress_bar.set_postfix({"best_latency": best_latency}) + if latency < best_latency: best_latency = latency best_config = config @@ -80,7 +108,8 @@ def jit( supply_type: tl.TensorSupplyType = tl.TensorSupplyType.Normal, ref_prog: Callable = None, rtol: float = 1e-5, - atol: float = 1e-5 + atol: float = 1e-5, + profiler: str = "torch" ) -> Callable: def wrapper(fn: Callable): @@ -91,9 +120,9 @@ def decorator(*args, **kwargs) -> float: mod, params = tl.lower(fn(*args, **kwargs)) mod = tl.Profiler(mod, params, out_idx, supply_type) mod.assert_allclose(ref_prog, rtol=rtol, atol=atol) - latency = mod.do_bench(mod.func, warmup = 25) + latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler=profiler) if ref_latency_cache is None and ref_prog is not None: - ref_latency_cache = mod.do_bench(ref_prog, warmup = 25) + ref_latency_cache = mod.do_bench(ref_prog, n_warmup=10, n_repeat=10, profiler=profiler) return latency, ref_latency_cache return decorator return wrapper \ No newline at end of file diff --git a/python/tvm/tl/language.py b/python/tvm/tl/language.py index 7b0a098ee849..9acf16511f91 100644 --- a/python/tvm/tl/language.py +++ b/python/tvm/tl/language.py @@ -43,7 
+43,15 @@ def Parallel(*extents: tir.PrimExpr): return _ffi_api.Parallel(extents) # type: ignore[attr-defined] # pylint: disable=no-member -def Pipelined(start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int = 0, order: List[int] = [], stage: List[int] = [], group: List[List[int]] = []): +def Pipelined( + start: tir.PrimExpr, + stop: tir.PrimExpr = None, + num_stages: int = 0, + order: List[int] = [], + stage: List[int] = [], + sync: List[List[int]] = [], + group: List[List[int]] = [] + ): """Tools to construct pipelined for loop. Parameters @@ -67,7 +75,7 @@ def Pipelined(start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int = else: start = 0 # type: ignore[attr-defined] # pylint: disable=no-member - return _ffi_api.Pipelined(start, stop, num_stages, order, stage, group) + return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group) @register_object("tl.KernelLaunchFrame") @@ -277,5 +285,9 @@ def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int): return reduce(buffer, out, "sum", dim, True) +def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int): + return reduce(buffer, out, "abssum", dim, True) + + def atomic_add(dst, value): return T.call_extern("handle", "atomicAdd", T.address_of(dst), value) diff --git a/python/tvm/tl/utils.py b/python/tvm/tl/utils.py index 33ade5f702fb..90c3e95ad61f 100644 --- a/python/tvm/tl/utils.py +++ b/python/tvm/tl/utils.py @@ -16,11 +16,14 @@ # under the License. """The profiler and convert to torch utils""" -from typing import Any, List +from typing import Any, List, Literal from enum import Enum from functools import partial import torch +import tvm +from torch.utils.dlpack import to_dlpack +from tvm.runtime import ndarray from tvm.relay import TensorType from tvm.contrib.dlpack import to_pytorch_func @@ -111,10 +114,10 @@ def __init__( super().__init__(mod, params, result_idx) self.supply = get_tensor_supply(supply_type) - def _get_inputs(self): + def _get_inputs(self, with_output=False): ins = [] for i in range(len(self.params)): - if i not in self.result_idx: + if with_output or i not in self.result_idx: ins.append(self.supply(self.params[i])) return ins @@ -159,10 +162,21 @@ def run_once(self, func=None): - def do_bench(self, func: callable, warmup=25, rep=100, n_warmup=0, n_repeat=0): - ins = self._get_inputs() - bench_func = partial(func, *ins) - return do_bench(bench_func, warmup=warmup, rep=rep, _n_warmup=n_warmup, _n_repeat=n_repeat) + def do_bench(self, func: callable, warmup=25, rep=100, n_warmup=0, n_repeat=0, profiler: Literal["torch", "tvm"] = "torch"): + if profiler == "torch": + ins = self._get_inputs() + bench_func = partial(func, *ins) + return do_bench(bench_func, warmup=warmup, rep=rep, _n_warmup=n_warmup, _n_repeat=n_repeat) + elif profiler == "tvm": + ins = self._get_inputs(with_output=True) + time_evaluator = self.mod.time_evaluator( + self.mod.entry_name, tvm.cuda(0), number=rep, repeat=n_repeat + ) + tvm_inputs = [ndarray.from_dlpack(to_dlpack(inp)) for inp in ins] + # Transform Latency to ms + return time_evaluator(*tvm_inputs).mean * 1e3 + else: + raise ValueError(f"Unknown profiler: {profiler}") def do_bench( diff --git a/src/tl/ir.cc b/src/tl/ir.cc index 5d440773e4a1..cde0e180cee8 100644 --- a/src/tl/ir.cc +++ b/src/tl/ir.cc @@ -54,7 +54,15 @@ ForFrame ParallelFor(Array extents) { return ForFrame(n); } -ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages, Array order, Array stages, Array> groups) { +ForFrame PipelinedFor( + PrimExpr start, + PrimExpr stop, + 
int num_stages, + Array order, + Array stages, + Array> sync, + Array> groups + ) { using namespace tvm::tir; ObjectPtr n = make_object(); DataType dtype = stop.dtype(); @@ -68,6 +76,7 @@ ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages, Array 0) anno.Set("num_stages", PrimExpr(num_stages)); anno.Set("software_pipeline_order", order); anno.Set("software_pipeline_stage", stages); + anno.Set("software_pipeline_sync", sync); anno.Set("software_pipeline_group", groups); body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, std::move(body), /*thread_binding=*/NullOpt, /*annotations=*/anno); diff --git a/src/tl/op/builtin.cc b/src/tl/op/builtin.cc index 487c460f9561..272e6591e02b 100644 --- a/src/tl/op/builtin.cc +++ b/src/tl/op/builtin.cc @@ -96,6 +96,10 @@ TIR_DEFINE_TL_BUILTIN(SetMaxNReg) .set_num_inputs(2) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(WaitWgmma) + .set_num_inputs(1) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(PackB16Op).set_num_inputs(2).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); } // namespace tl diff --git a/src/tl/op/builtin.h b/src/tl/op/builtin.h index 49e1c21101ee..fa32f7e3c788 100644 --- a/src/tl/op/builtin.h +++ b/src/tl/op/builtin.h @@ -154,6 +154,14 @@ const Op& FenceProxyAsyncOp(); */ const Op& SetMaxNReg(); +/*! + * \brief Wait the previous wgmma to finish + * + * WaitWgmma(num_mma) + * + */ +const Op& WaitWgmma(); + } // namespace tl } // namespace tvm diff --git a/src/tl/op/reduce.cc b/src/tl/op/reduce.cc index c057a123ee08..da32ecf7b762 100644 --- a/src/tl/op/reduce.cc +++ b/src/tl/op/reduce.cc @@ -44,6 +44,8 @@ ReduceOp::ReduceOp(Array args, BufferMap vmap) { dim = args[3].as().value()->value; if (reduce_type == "sum") type = ReduceType::kSum; + else if (reduce_type == "abssum") + type = ReduceType::kAbsSum; else if (reduce_type == "max") type = ReduceType::kMax; else if (reduce_type == "min") @@ -57,6 +59,8 @@ PrimExpr ReduceOp::MakeInitValue() const { switch (type) { case ReduceType::kSum: return make_zero(dst->dtype); + case ReduceType::kAbsSum: + return make_zero(dst->dtype); case ReduceType::kMax: return make_const(dst->dtype, -INFINITY); case ReduceType::kMin: @@ -74,6 +78,8 @@ PrimExpr ReduceOp::MakeReduce(const PrimExpr& a, const PrimExpr& b) const { switch (type) { case ReduceType::kSum: return lhs + rhs; + case ReduceType::kAbsSum: + return lhs + Max(rhs, -rhs); case ReduceType::kMax: return Max(lhs, rhs); case ReduceType::kMin: @@ -88,6 +94,8 @@ std::string ReduceOp::MakeCodegenReducer() const { switch (type) { case ReduceType::kSum: return "tl::SumOp"; + case ReduceType::kAbsSum: + return "tl::SumOp"; case ReduceType::kMax: return "tl::MaxOp"; case ReduceType::kMin: diff --git a/src/tl/op/reduce.h b/src/tl/op/reduce.h index 8e1a0bd6bb5a..3fc4552dc5ca 100644 --- a/src/tl/op/reduce.h +++ b/src/tl/op/reduce.h @@ -45,6 +45,7 @@ class ReduceOp : public Operator { int dim; enum class ReduceType { kSum, + kAbsSum, kMax, kMin, } type; diff --git a/src/tl/tl_templates/gemm_sm90.h b/src/tl/tl_templates/gemm_sm90.h index 372a947579c7..a5620389b9b0 100644 --- a/src/tl/tl_templates/gemm_sm90.h +++ b/src/tl/tl_templates/gemm_sm90.h @@ -67,6 +67,7 @@ class GemmTensorOp { static_assert(num_warp_n == 1); static_assert(num_warp_m % 4 == 0); + template static CUTE_DEVICE void body(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) { const int tid = threadIdx.x; Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast(pA)), 
SmemLayoutA{}); @@ -88,18 +89,33 @@ class GemmTensorOp { partition_shape_C(tiled_mma, Shape, Int>{})); warpgroup_fence_operand(acc); - warpgroup_arrive(); - - gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc); + if(k_block == 0) { + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } warpgroup_commit_batch(); - warpgroup_wait<0>(); + if constexpr (wg_wait >= 0) { warpgroup_wait(); } warpgroup_fence_operand(acc); + // warpgroup_fence_operand(acc); + // warpgroup_arrive(); + + // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc); + + // warpgroup_commit_batch(); + // if constexpr (wg_wait >= 0) { warpgroup_wait(); } + // warpgroup_fence_operand(acc); } + template static CUTE_DEVICE void body_rs(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) { // TODO: Move bar.sync out of body_rs - asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n * 32)); + // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n * 32)); const int tid = threadIdx.x; Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), SmemLayoutB{}); auto tiled_mma = @@ -116,29 +132,31 @@ class GemmTensorOp { Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast(pC)), partition_shape_C(tiled_mma, Shape, Int>{})); - // warpgroup_fence_operand(tCrA); - // warpgroup_fence_operand(acc); - // for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // warpgroup_arrive(); - // // (V,M) x (V,N) => (V,M,N) - // gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc); - // if(k_block == 0) { - // tiled_mma.accumulate_ = GMMA::ScaleOut::One; - // } - // warpgroup_commit_batch(); - // } - // warpgroup_wait<0>(); - // warpgroup_fence_operand(acc); - // warpgroup_fence_operand(tCrA); - + warpgroup_fence_operand(tCrA); warpgroup_fence_operand(acc); - warpgroup_arrive(); - - gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc); - + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc); + if(k_block == 0) { + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } warpgroup_commit_batch(); - warpgroup_wait<0>(); + if constexpr (wg_wait >= 0) { warpgroup_wait(); } warpgroup_fence_operand(acc); + warpgroup_fence_operand(tCrA); + + // warpgroup_fence_operand(acc); + // warpgroup_arrive(); + + // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc); + + // warpgroup_commit_batch(); + + // if constexpr (wg_wait >= 0) { warpgroup_wait(); } + // warpgroup_fence_operand(acc); } }; @@ -146,20 +164,24 @@ class GemmTensorOp { namespace tl { -template TL_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) { using MMA = cute::GemmTensorOp; - MMA::body(pA, pB, accum); + MMA::body(pA, pB, accum); } -template TL_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) { using MMA = cute::GemmTensorOp; - MMA::body_rs(pA, pB, accum); + MMA::body_rs(pA, pB, accum); } +template +TL_DEVICE void wait_wgmma() { + warpgroup_wait(); +} } // namespace tl diff --git a/tl_scripts/retnet_example.py b/tl_scripts/retnet_example.py index 198eda87a185..7150f3071b5f 100644 --- a/tl_scripts/retnet_example.py +++ b/tl_scripts/retnet_example.py @@ -5,8 +5,7 @@ from functools import partial -def retnet(batch, heads, seq_len, dim, is_casual, block_M, block_N): - scale = (1.0 / dim) ** 
0.5 * 1.44269504 # log2(e) +def retnet(batch, heads, seq_len, dim, block_M, block_N): shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @@ -16,88 +15,96 @@ def main( Q: T.Buffer(shape, dtype), K: T.Buffer(shape, dtype), V: T.Buffer(shape, dtype), + mask: T.Buffer([heads, seq_len, seq_len], dtype), Output: T.Buffer(shape, dtype), ): - with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128 * 1) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) - Q_local = T.alloc_fragment([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) + mask_shared = T.alloc_shared([block_M, block_N], dtype) + acc_o_shared = T.alloc_shared([block_M, dim], dtype) + mask_local = T.alloc_fragment([block_M, block_N], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - lse = T.alloc_fragment([block_M], accum_dtype) + abs_sum = T.alloc_fragment([block_M], accum_dtype) + r_wo_clamp = T.alloc_fragment([block_M], accum_dtype) + r = T.alloc_fragment([block_M], accum_dtype) + r_new = T.alloc_fragment([block_M], accum_dtype) T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + + T.fill(r, 0) + T.fill(r_new, 0) + T.fill(r_wo_clamp, 0) T.fill(acc_o, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.fill(scores_max_prev, -T.infinity(accum_dtype)) - T.fill(lse, -T.infinity(accum_dtype)) - T.copy(Q_shared, Q_local) loop_range = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_range, num_stages=1): + for k in T.Pipelined(loop_range, + num_stages=1, + order=[-1,0,-1,-1,1,2], + stage=[-1,0,-1,-1,0,0], + group=[[0],[1,2],[3],[4],[5,6,7,8,9,10,11,12,13], [14]] + ): T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) T.clear(acc_s) - T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_M): - scores_max[i] = T.max(scores_max_prev[i], scores_max[i]) + T.copy(mask[by, bx * block_M : (bx + 1) * block_M, k * block_N : (k + 1) * block_N], mask_shared) + T.copy(mask_shared, mask_local) for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i]) - T.reduce_sum(acc_s, scores_sum, dim=1) + acc_s[i, j] = acc_s[i, j] * mask_local[i, j] + T.reduce_abssum(acc_s, abs_sum, dim=1) for i in T.Parallel(block_M): - lse[i] = scores_max[i] + T.log(T.exp(lse[i] - scores_max[i]) + scores_sum[i]) + r_wo_clamp[i] = r_wo_clamp[i] + abs_sum[i] + for i in T.Parallel(block_M): + r_new[i] = T.max(r_wo_clamp[i], 1) for i, j in T.Parallel(block_M, dim): - acc_o[i, j] = acc_o[i, j] * T.exp(scores_max_prev[i] - scores_max[i]) - T.copy(scores_max, scores_max_prev) + acc_o[i, j] = T.if_then_else(k > 0, acc_o[i, j] * r[i] / r_new[i], acc_o[i, j]) + T.copy(r_new, r) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = acc_s[i, j] / r_new[i] T.copy(acc_s, acc_s_cast) 
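+ # acc_s now holds (Q @ K^T * mask) / r_new and acc_o has just been rescaled by
+ # r / r_new, so after the gemm below acc_o equals the running sum of
+ # (Q @ K^T * mask) @ V divided by the current r_new, i.e. the blockwise form of
+ # the qkm / r @ V computation in ref_program.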
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= T.exp(scores_max[i] - lse[i]) - T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + T.copy(acc_o, acc_o_shared) + T.copy(acc_o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) return main -def ref_program(Q, K, V, casual): - qk = torch.matmul(Q.permute(0, 2, 1, 3), K.permute(0, 2, 3, 1)) # [B, H, SEQLEN, SEQLEN] - m = qk.max(dim=-1, keepdim=True).values - p = torch.exp(qk - m) - s = p / p.sum(dim=-1, keepdim=True) - o = torch.matmul(s.to(torch.float16), V.permute(0, 2, 1, 3)) # [B, H, SEQLEN, dim] - return o.permute(0, 2, 1, 3) +def ref_program(Q, K, V, mask): + Q = Q.to(dtype=float) + K = K.to(dtype=float) + V = V.to(dtype=float) + mask = mask.to(dtype=float) + qk = torch.einsum('bqhd,bkhd->bhqk', Q, K) + qkm = qk * mask + r = qkm.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1.0) + o = torch.einsum('bhqk,bkhd->bqhd', qkm/r, V) + return o.to(dtype=torch.float16) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='Batch size') - parser.add_argument('--h', type=int, default=12, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=2048, help='Context size') - parser.add_argument('--d_head', type=int, default=256, help='Head dimension') - parser.add_argument('--casual', type=bool, default=True, help='Casual flag') + parser.add_argument('--batch', type=int, default=1, help='Batch size') + parser.add_argument('--h', type=int, default=32, help='Number of heads') + parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') + parser.add_argument('--d_head', type=int, default=128, help='Head dimension') args = parser.parse_args() BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head - casual = args.casual flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul - if casual: - total_flops *= 0.5 BLOCK_M = 64 - BLOCK_N = 64 if D_HEAD <= 128 else 32 - program = retnet(BATCH, H, N_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) - ref_program = partial(ref_program, casual=casual) + BLOCK_N = 64 + program = retnet(BATCH, H, N_CTX, D_HEAD, BLOCK_M, BLOCK_N) mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) - mod.assert_allclose(ref_program, rtol=0.1, atol=0.1) + mod = tl.Profiler(mod, params, [4], tl.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) - # latency = mod.do_bench(ref_program, n_warmup=10, n_repeat=1) - # print("torch: {:.2f} ms".format(latency)) - # print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - latency = mod.do_bench(mod, n_warmup=10, n_repeat=5) + latency = mod.do_bench(ref_program, n_warmup=10, n_repeat=1) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = mod.do_bench(mod, n_warmup=10, n_repeat=10) print("tl: {:.2f} ms".format(latency)) print("tl: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/tl_scripts/torch_ref.py b/tl_scripts/torch_ref.py index 704a934d0be9..39890b36987e 100644 --- a/tl_scripts/torch_ref.py +++ b/tl_scripts/torch_ref.py @@ -53,80 +53,200 @@ def set_seed(seed): # are_close = torch.allclose(ref_output, test_output, rtol=1e-03, atol=1e-03) # print(f"Are the outputs close? 
{are_close}") -import torch.nn.functional as F +# import torch.nn.functional as F -batch = 1 -seq_len = 1024 -heads = 1 -dim = 64 -shape = [batch, seq_len, heads, dim] -Q = torch.randn(shape, device="cuda", dtype=torch.float16) -K = torch.randn(shape, device="cuda", dtype=torch.float16) -V = torch.randn(shape, device="cuda", dtype=torch.float16) +# batch = 1 +# seq_len = 128 +# heads = 1 +# dim = 64 +# shape = [batch, seq_len, heads, dim] +# # Q = torch.randn(shape, device="cuda", dtype=torch.float16) +# # K = torch.randn(shape, device="cuda", dtype=torch.float16) +# # V = torch.randn(shape, device="cuda", dtype=torch.float16) # Q = torch.ones(shape, device="cuda", dtype=torch.float16) # K = torch.ones(shape, device="cuda", dtype=torch.float16) # V = torch.ones(shape, device="cuda", dtype=torch.float16) -def test_program(Q, K, V): - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - block_M = seq_len - block_N = 64 - acc_s = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float) - acc_s_cast = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float16) - acc_o = torch.empty((batch, block_M, heads, dim), device="cuda", dtype=torch.float) - scores_max = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_max_prev = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_scale = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_sum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - logsum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - acc_o.fill_(0) - logsum.fill_(0) - scores_max.fill_(float('-inf')) - Q *= scale - - for i in range(int(seq_len / block_N)): - acc_s.fill_(0) - acc_s = torch.einsum('bqhd,bkhd->bhqk', Q, K[:, i * block_N : (i + 1) * block_N, :, :]) # [batch, seqlen, heads, block_N] - scores_max_prev = scores_max - scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] - scores_scale = torch.exp2(scores_max_prev - scores_max) - acc_o *= scores_scale[:, :, :, None].transpose(1, 2) - acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) - # print("acc_s:", acc_s) - acc_s_cast = acc_s.to(torch.float16) - acc_o += torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, i * block_N : (i + 1) * block_N, :, :]) - scores_sum = acc_s.sum(dim=-1, keepdim=False) - logsum = logsum * scores_scale + scores_sum - # print("acc_o:", acc_o.size()) - # print("logsum:", logsum.size()) - acc_o /= logsum[:, :, :, None].transpose(1, 2) - return acc_o.to(torch.float16) - - -def ref_program(Q, K, V): - dim = Q.size(-1) - scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) +# # def test_program(Q, K, V): +# # scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) +# # block_M = seq_len +# # block_N = 64 +# # acc_s = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float) +# # acc_s_cast = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float16) +# # acc_o = torch.empty((batch, block_M, heads, dim), device="cuda", dtype=torch.float) +# # scores_max = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# # scores_max_prev = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# # scores_scale = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# # scores_sum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# # logsum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# # 
acc_o.fill_(0) +# # logsum.fill_(0) +# # scores_max.fill_(float('-inf')) +# # Q *= scale + +# # for i in range(int(seq_len / block_N)): +# # acc_s.fill_(0) +# # acc_s = torch.einsum('bqhd,bkhd->bhqk', Q, K[:, i * block_N : (i + 1) * block_N, :, :]) # [batch, seqlen, heads, block_N] +# # scores_max_prev = scores_max +# # scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] +# # scores_scale = torch.exp2(scores_max_prev - scores_max) +# # acc_o *= scores_scale[:, :, :, None].transpose(1, 2) +# # acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) +# # # print("acc_s:", acc_s) +# # acc_s_cast = acc_s.to(torch.float16) +# # acc_o += torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, i * block_N : (i + 1) * block_N, :, :]) +# # scores_sum = acc_s.sum(dim=-1, keepdim=False) +# # logsum = logsum * scores_scale + scores_sum +# # # print("acc_o:", acc_o.size()) +# # # print("logsum:", logsum.size()) +# # acc_o /= logsum[:, :, :, None].transpose(1, 2) +# # return acc_o.to(torch.float16) + +# def test_program(Q, K, V): +# scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) +# block_M = seq_len +# block_N = 64 +# acc_s = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float) +# acc_s_cast = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float16) +# acc_o = torch.empty((batch, block_M, heads, dim), device="cuda", dtype=torch.float) +# scores_max = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# scores_max_prev = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# scores_scale = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# scores_sum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# # logsum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) +# acc_o.fill_(0) +# scores_scale.fill_(0) +# # logsum.fill_(0) +# scores_max.fill_(float('-inf')) + +# for i in range(int(seq_len / block_N)): +# acc_s.fill_(0) +# acc_s = torch.einsum('bqhd,bkhd->bhqk', Q, K[:, i * block_N : (i + 1) * block_N, :, :]) + +# if (i == 0): +# scores_max = acc_s.max(dim=-1, keepdim=False).values +# acc_s = torch.exp2(acc_s * scale - scores_max[:, :, :, None] * scale) +# scores_sum = acc_s.sum(dim=-1, keepdim=False) +# scores_scale.fill_(1) + +# # cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); +# if (i > 0): +# scores_max_prev = scores_max +# scores_max = acc_s.max(dim=-1, keepdim=False).values +# print("scores_max_prev:", scores_max_prev) +# print("scores_max:", scores_max) +# print("scores_max_prev - scores_max:", scores_max_prev - scores_max) +# scores_scale = torch.exp2((scores_max_prev - scores_max) * scale) +# scores_sum *= scores_scale + +# # online_softmax +# if (i > 0): +# acc_s = torch.exp2(acc_s * scale - scores_max[:, :, :, None] * scale) +# scores_sum = acc_s.sum(dim=-1, keepdim=False) + +# # rescale_o +# if (i > 0): +# acc_o *= scores_scale[:, :, :, None].transpose(1, 2) # softmax.rescale_o(tOrO, scores_scale); + +# acc_s_cast = acc_s.to(torch.float16) +# print("acc_s_cast:", acc_s_cast) +# acc_o += torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, i * block_N : (i + 1) * block_N, :, :]) +# print("acc_o:", acc_o) + +# if (i == int(seq_len / block_N) - 1): +# acc_o /= scores_sum[:, :, :, None].transpose(1, 2) + +# print("scores_scale:", scores_scale) +# print("scores_sum:", scores_sum) +# # scores_sum = scores_max * scale * 0.69314718055994530942 + torch.log(scores_sum) +# # acc_o /= scores_sum[:, :, 
:, None].transpose(1, 2) +# acc_o *= scores_scale[:, :, :, None].transpose(1, 2) # softmax.rescale_o(tOrO, scores_scale); +# return acc_o.to(torch.float16) + + +# # flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); +# # First cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); +# # cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); +# # softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); +# # flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); +# # softmax.rescale_o(tOrO, scores_scale); +# # finalize + +# def ref_program(Q, K, V): +# dim = Q.size(-1) +# scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) - # Step 2: Scale the scores by the square root of dim - scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) +# # Step 2: Scale the scores by the square root of dim +# scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) - # Step 3: Apply softmax to get the attention weights - attention_weights = F.softmax(scores, dim=-1) +# # Step 3: Apply softmax to get the attention weights +# attention_weights = F.softmax(scores, dim=-1) - # print("scores:", attention_weights) - # Step 4: Multiply the attention weights by the values (V) - # This gives us the final output of shape [batch, seq_len, heads, dim] - output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) +# # print("scores:", attention_weights) +# # Step 4: Multiply the attention weights by the values (V) +# # This gives us the final output of shape [batch, seq_len, heads, dim] +# output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) - return output +# return output + +# ref_output = ref_program(Q, K, V) +# test_output = test_program(Q, K, V) +# are_close = torch.allclose(ref_output, test_output, rtol=1e-03, atol=1e-03) +# print(f"Are the outputs close? {are_close}") + +# print("ref_output:", ref_output) +# print("test_output:", test_output) -ref_output = ref_program(Q, K, V) -test_output = test_program(Q, K, V) -are_close = torch.allclose(ref_output, test_output, rtol=1e-03, atol=1e-03) -print(f"Are the outputs close? 
{are_close}") -print("ref_output:", ref_output) -print("test_output:", test_output) \ No newline at end of file +# RetNet +batch = 8 +head = 16 +seq_len = 1024 +dim = 64 +blockM = seq_len +blockN = 64 +# q = torch.randn((BATCH, HEAD, SEQLEN, D), dtype=torch.float16) +# k = torch.randn((BATCH, HEAD, SEQLEN, D), dtype=torch.float16) +# v = torch.randn((BATCH, HEAD, SEQLEN, D), dtype=torch.float16) +# mask = torch.randn((HEAD, SEQLEN, SEQLEN), dtype=torch.float16) + +shape = [batch, seq_len, head, dim] +q = torch.randn(shape, dtype=torch.float16) +k = torch.randn(shape, dtype=torch.float16) +v = torch.randn(shape, dtype=torch.float16) +mask = torch.randn((head, seq_len, seq_len), dtype=torch.float16) + +def ref_program(q, k, v, mask): + qk = torch.einsum('bqhd,bkhd->bhqk', q, k) + qkm = qk * mask.unsqueeze(0) + r = qkm.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1.0) + o = torch.einsum('bhqk,bkhd->bqhd', qkm/r, v) + return o + +def test_program(q, k, v, mask): + r = torch.zeros((batch, head, seq_len), dtype=torch.float16) + r_new = torch.zeros((batch, head, seq_len), dtype=torch.float16) + r_wo_clamp = torch.zeros((batch, head, seq_len), dtype=torch.float16) + acco = torch.zeros((batch, head, seq_len, dim), dtype=torch.float16) + for i in range(seq_len // blockN): + qk = torch.einsum('bqhd,bkhd->bhqk', q, k[:, i * blockN : (i + 1) * blockN, :, :]) + qkm = qk * mask[:, :, i * blockN : (i + 1) * blockN].unsqueeze(0) + r_wo_clamp += qkm.detach().abs().sum(dim=-1, keepdim=False) + r_new = torch.max(r_wo_clamp, torch.ones_like(r_wo_clamp)) + if (i != 0): + acco = acco * r.unsqueeze(-1) / r_new.unsqueeze(-1) + r = r_new + acco += torch.einsum('bhqk,bkhd->bhqd', qkm / r_new.unsqueeze(-1), v[:, i * blockN : (i + 1) * blockN, :, :]) + return acco.transpose(1, 2) + +ref_output = ref_program(q, k, v, mask) +test_output = test_program(q, k, v, mask) +are_close = torch.allclose(ref_output, test_output, rtol=1e-02, atol=1e-02) +print(f"Are the outputs close? 
{are_close}") +# print("ref_output:", ref_output) +# print("test_output:", test_output) +# print("ref_output shape:", ref_output.size()) +# print("test_output shape:", test_output.size()) \ No newline at end of file From cb95f23447036e9e84d0c7f641fa00b4d4587a24 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Thu, 26 Sep 2024 09:52:10 +0000 Subject: [PATCH 18/23] [tl] Update --- .gitmodules | 6 + 3rdparty/fa3 | 1 + 3rdparty/flash-linear-attention | 1 + python/tvm/contrib/nvcc.py | 5 +- python/tvm/tl/modified_code.cu | 191 ++++++++++++++++++++++++++++ src/tl/layout/gemm_layouts.cc | 8 +- src/tl/op/gemm.cc | 11 +- src/tl/target/codegen.cc | 4 + src/tl/tl_templates/gemm_sm90.h | 10 +- src/tl/tl_templates/reduce.h | 2 + tl_scripts/gemm_hopper.py | 51 ++++++++ tl_scripts/mha_pipeline.py | 30 +++-- tl_scripts/mha_pipeline_search.py | 202 ++++++++++++++++++++++++++++++ tl_scripts/retnet_example.py | 73 ++++++----- tl_scripts/torch_ref.py | 14 ++- 15 files changed, 548 insertions(+), 61 deletions(-) create mode 160000 3rdparty/fa3 create mode 160000 3rdparty/flash-linear-attention create mode 100644 python/tvm/tl/modified_code.cu create mode 100644 tl_scripts/gemm_hopper.py create mode 100644 tl_scripts/mha_pipeline_search.py diff --git a/.gitmodules b/.gitmodules index d2cb394d8997..e9fefa469390 100644 --- a/.gitmodules +++ b/.gitmodules @@ -31,3 +31,9 @@ [submodule "cutlass"] path = cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/fa3"] + path = 3rdparty/fa3 + url = git@github.com:Dao-AILab/flash-attention.git +[submodule "3rdparty/flash-linear-attention"] + path = 3rdparty/flash-linear-attention + url = git@github.com:sustcsonglin/flash-linear-attention.git diff --git a/3rdparty/fa3 b/3rdparty/fa3 new file mode 160000 index 000000000000..74b0761ff7ef --- /dev/null +++ b/3rdparty/fa3 @@ -0,0 +1 @@ +Subproject commit 74b0761ff7efc7b90d4e5aeb529c1b2a09a7458c diff --git a/3rdparty/flash-linear-attention b/3rdparty/flash-linear-attention new file mode 160000 index 000000000000..33b89d5415e9 --- /dev/null +++ b/3rdparty/flash-linear-attention @@ -0,0 +1 @@ +Subproject commit 33b89d5415e951718ccdea74e695ad807ddddf96 diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 5eb348009914..7f3b1adce4e5 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -29,7 +29,7 @@ from . import utils -def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None): +def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, get_output=False): """Compile cuda code with NVCC from env. 
Parameters @@ -106,6 +106,9 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target (out, _) = proc.communicate() + if get_output: + print(py_str(out)) + if proc.returncode != 0: msg = code msg += "\nCompilation error:\n" diff --git a/python/tvm/tl/modified_code.cu b/python/tvm/tl/modified_code.cu new file mode 100644 index 000000000000..f8d90d99de36 --- /dev/null +++ b/python/tvm/tl/modified_code.cu @@ -0,0 +1,191 @@ +#include +#include +#include +#include +#include + +extern "C" __global__ void __launch_bounds__(384) main_kernel(__grid_constant__ const CUtensorMap K_desc, __grid_constant__ const CUtensorMap Output_desc, __grid_constant__ const CUtensorMap Q_desc, __grid_constant__ const CUtensorMap V_desc, __grid_constant__ const CUtensorMap mask_desc) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + float r[2]; + float r_new[2]; + float r_wo_clamp[2]; + float acc_o[112]; + float acc_s[16]; + half_t mask_local[16]; + float acc_s_1[32]; + float abs_sum[2]; + half_t acc_s_cast[32]; + __shared__ uint64_t _mbarrier[11]; + if (((int)threadIdx.x) == 0) { + tl::prefetch_tma_descriptor(Q_desc); + tl::prefetch_tma_descriptor(K_desc); + tl::prefetch_tma_descriptor(mask_desc); + tl::prefetch_tma_descriptor(V_desc); + tl::prefetch_tma_descriptor(Output_desc); + tl::mbarrier_init(_mbarrier[0], 128); + tl::mbarrier_init(_mbarrier[1], 128); + tl::mbarrier_init(_mbarrier[2], 128); + tl::mbarrier_init(_mbarrier[3], 256); + tl::mbarrier_init(_mbarrier[4], 256); + tl::mbarrier_init(_mbarrier[5], 256); + tl::mbarrier_init(_mbarrier[6], 128); + tl::mbarrier_init(_mbarrier[7], 256); + tl::mbarrier_init(_mbarrier[8], 256); + tl::mbarrier_init(_mbarrier[9], 256); + tl::mbarrier_init(_mbarrier[10], 256); + } + __syncthreads(); + if (256 <= ((int)threadIdx.x)) { + tl::warpgroup_reg_dealloc<24>(); + // if (((int)threadIdx.x) == 256) { + // tl::mbarrier_expect_tx(_mbarrier[6], 32768); + // } + // if (((int)threadIdx.x) == 256) { + // tl::tma_load(Q_desc, _mbarrier[6], (&(((half_t*)buf_dyn_shmem)[12288])), 0, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + // tl::tma_load(Q_desc, _mbarrier[6], (&(((half_t*)buf_dyn_shmem)[16384])), 64, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + // tl::tma_load(Q_desc, _mbarrier[6], (&(((half_t*)buf_dyn_shmem)[20480])), 128, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + // tl::tma_load(Q_desc, _mbarrier[6], (&(((half_t*)buf_dyn_shmem)[24576])), 192, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + // } + // tl::mbarrier_arrive(_mbarrier[6]); + // for (int k = 0; k < 128; ++k) { + // tl::mbarrier_wait(_mbarrier[3], ((k & 1) ^ 1)); + // if (((int)threadIdx.x) == 256) { + // tl::mbarrier_expect_tx(_mbarrier[0], 32768); + // } + // if (((int)threadIdx.x) == 256) { + // tl::tma_load(K_desc, _mbarrier[0], (&(((half_t*)buf_dyn_shmem)[28672])), 0, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(K_desc, _mbarrier[0], (&(((half_t*)buf_dyn_shmem)[32768])), 64, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(K_desc, _mbarrier[0], (&(((half_t*)buf_dyn_shmem)[36864])), 128, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(K_desc, _mbarrier[0], (&(((half_t*)buf_dyn_shmem)[40960])), 192, ((int)blockIdx.y), (k * 64), 0); + // } + // tl::mbarrier_arrive(_mbarrier[0]); + // tl::mbarrier_wait(_mbarrier[4], ((k & 1) ^ 1)); + // if (((int)threadIdx.x) == 256) { + // tl::mbarrier_expect_tx(_mbarrier[1], 8192); + // } + // if (((int)threadIdx.x) == 256) { + // tl::tma_load(mask_desc, _mbarrier[1], (&(((half_t*)buf_dyn_shmem)[0])), (k * 64), 
(((int)blockIdx.x) * 64), ((int)blockIdx.y)); + // } + // tl::mbarrier_arrive(_mbarrier[1]); + // tl::mbarrier_wait(_mbarrier[5], ((k & 1) ^ 1)); + // if (((int)threadIdx.x) == 256) { + // tl::mbarrier_expect_tx(_mbarrier[2], 57344); + // } + // if (((int)threadIdx.x) == 256) { + // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[45056])), 0, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[49152])), 64, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[53248])), 128, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[57344])), 192, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[61440])), 256, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[65536])), 320, ((int)blockIdx.y), (k * 64), 0); + // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[69632])), 384, ((int)blockIdx.y), (k * 64), 0); + // } + // tl::mbarrier_arrive(_mbarrier[2]); + // } + } else { + tl::warpgroup_reg_alloc<240>(); + #pragma unroll + for (int i = 0; i < 2; ++i) { + r[i] = 0.000000e+00f; + } + #pragma unroll + for (int i_1 = 0; i_1 < 2; ++i_1) { + r_new[i_1] = 0.000000e+00f; + } + #pragma unroll + for (int i_2 = 0; i_2 < 2; ++i_2) { + r_wo_clamp[i_2] = 0.000000e+00f; + } + #pragma unroll + for (int i_3 = 0; i_3 < 112; ++i_3) { + acc_o[i_3] = 0.000000e+00f; + } + tl::fence_proxy_async(); + // tl::mbarrier_wait(_mbarrier[6], 0); + #pragma unroll 1 + for (int k_1 = 0; k_1 < 128; ++k_1) { + #pragma unroll + for (int i_4 = 0; i_4 < 16; ++i_4) { + acc_s[i_4] = 0.000000e+00f; + } + tl::fence_proxy_async(); + // tl::mbarrier_wait(_mbarrier[0], (k_1 & 1)); + tl::gemm_ss<64, 64, 256, 4, 2, 0, 1>((&(((half_t*)buf_dyn_shmem)[12288])), (&(((half_t*)buf_dyn_shmem)[28672])), (&(acc_s[0]))); + // tl::mbarrier_arrive(_mbarrier[3]); + // tl::mbarrier_wait(_mbarrier[1], (k_1 & 1)); + #pragma unroll + for (int i_5 = 0; i_5 < 2; ++i_5) { + tl::ptx_ldmatrix_x4((&(((half_t*)buf_dyn_shmem)[(((((((((int)threadIdx.x) & 127) >> 5) * 1024) + ((((int)threadIdx.x) & 15) * 64)) + ((((int)threadIdx.x) >> 7) * 32)) + (i_5 * 16)) + (((((int)threadIdx.x) & 31) >> 4) * 8))])), (&(mask_local[(i_5 * 8)]))); + } + tl::fence_proxy_async(); + // tl::mbarrier_arrive(_mbarrier[4]); + #pragma unroll + for (int i_6 = 0; i_6 < 16; ++i_6) { + acc_s[i_6] = (acc_s[i_6] * ((float)mask_local[i_6])); + } + tl::syncthreads_partial(_mbarrier[7]); + #pragma unroll + for (int i_7 = 0; i_7 < 8; ++i_7) { + ((float2*)buf_dyn_shmem)[(((((((((((int)threadIdx.x) & 127) >> 5) * 512) + ((i_7 & 1) * 256)) + (((((int)threadIdx.x) & 31) >> 2) * 32)) + ((((int)threadIdx.x) >> 7) * 16)) + ((i_7 >> 1) * 4)) + (((int)threadIdx.x) & 3)) + 1024)] = *(float2*)(acc_s + (i_7 * 2)); + } + tl::syncthreads_partial(_mbarrier[8]); + #pragma unroll + for (int i_8 = 0; i_8 < 16; ++i_8) { + *(float2*)(acc_s_1 + (i_8 * 2)) = ((float2*)buf_dyn_shmem)[((((((((((int)threadIdx.x) & 127) >> 5) * 512) + ((i_8 & 1) * 256)) + (((((int)threadIdx.x) & 31) >> 2) * 32)) + ((i_8 >> 1) * 4)) + (((int)threadIdx.x) & 3)) + 1024)]; + } + #pragma unroll + for (int i_9 = 0; i_9 < 2; ++i_9) { + abs_sum[i_9] = 0.000000e+00f; + #pragma unroll + for (int rv = 0; rv < 16; ++rv) { + abs_sum[i_9] = (abs_sum[i_9] + max(acc_s_1[((((rv & 7) * 4) + (i_9 * 2)) + (rv >> 3))], (0.000000e+00f - acc_s_1[((((rv & 7) * 4) + (i_9 * 2)) + (rv >> 3))]))); + } 
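+      // Reduce the per-thread partial sums so every thread that owns part of this
+      // row sees the complete row-wise abs-sum before the clamp below.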
+ abs_sum[i_9] = tl::AllReduce::run(abs_sum[i_9]); + } + #pragma unroll + for (int i_10 = 0; i_10 < 2; ++i_10) { + r_wo_clamp[i_10] = (r_wo_clamp[i_10] + abs_sum[i_10]); + } + #pragma unroll + for (int i_11 = 0; i_11 < 2; ++i_11) { + r_new[i_11] = max(r_wo_clamp[i_11], 1.000000e+00f); + } + #pragma unroll + for (int i_12 = 0; i_12 < 112; ++i_12) { + acc_o[i_12] = ((0 < k_1) ? ((acc_o[i_12] * r[((i_12 & 3) >> 1)]) / r_new[((i_12 & 3) >> 1)]) : acc_o[i_12]); + } + #pragma unroll + for (int i_13 = 0; i_13 < 2; ++i_13) { + r[i_13] = r_new[i_13]; + } + #pragma unroll + for (int i_14 = 0; i_14 < 32; ++i_14) { + acc_s_1[i_14] = (acc_s_1[i_14] / r_new[((i_14 & 3) >> 1)]); + } + #pragma unroll + for (int i_15 = 0; i_15 < 32; ++i_15) { + acc_s_cast[i_15] = ((half_t)acc_s_1[i_15]); + } + tl::fence_proxy_async(); + // tl::mbarrier_wait(_mbarrier[2], (k_1 & 1)); + tl::gemm_rs<64, 448, 64, 4, 2, 0, 0>((&(acc_s_cast[0])), (&(((half_t*)buf_dyn_shmem)[45056])), (&(acc_o[0]))); + // tl::mbarrier_arrive(_mbarrier[5]); + } + tl::syncthreads_partial(_mbarrier[9]); + #pragma unroll + for (int i_16 = 0; i_16 < 14; ++i_16) { + tl::ptx_stmatrix_x4((&(((half_t*)buf_dyn_shmem)[(((((((((((((int)threadIdx.x) >> 7) * 7) + (i_16 >> 1)) >> 1) * 4096) + (((((int)threadIdx.x) & 127) >> 5) * 1024)) + ((((int)threadIdx.x) & 15) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + (((((int)threadIdx.x) >> 7) + (i_16 >> 1)) & 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (i_16 & 1)) & 1) * 16)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 45056)])), __pack_half2(((half_t)acc_o[(i_16 * 8)]), ((half_t)acc_o[((i_16 * 8) + 1)])), __pack_half2(((half_t)acc_o[((i_16 * 8) + 2)]), ((half_t)acc_o[((i_16 * 8) + 3)])), __pack_half2(((half_t)acc_o[((i_16 * 8) + 4)]), ((half_t)acc_o[((i_16 * 8) + 5)])), __pack_half2(((half_t)acc_o[((i_16 * 8) + 6)]), ((half_t)acc_o[((i_16 * 8) + 7)]))); + } + tl::fence_proxy_async(); + tl::syncthreads_partial(_mbarrier[10]); + if (((int)threadIdx.x) == 0) { + tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[45056])), 0, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[49152])), 64, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[53248])), 128, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[57344])), 192, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[61440])), 256, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[65536])), 320, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[69632])), 384, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); + } + } +} \ No newline at end of file diff --git a/src/tl/layout/gemm_layouts.cc b/src/tl/layout/gemm_layouts.cc index f2a0110c001d..e60b7449ab76 100644 --- a/src/tl/layout/gemm_layouts.cc +++ b/src/tl/layout/gemm_layouts.cc @@ -83,12 +83,12 @@ Fragment makeGemmFragmentC(const int block_m, const int block_n, const int warp_ Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, const int warp_m, const int warp_n, const int element_size) { ICHECK(block_m % warp_m == 0); - ICHECK(block_n == warp_n); + // ICHECK(block_n == warp_n); ICHECK(warp_m % 16 == 0); auto warp_layout = - makeGemmFragment8x8()->Repeat({2, block_n / 8}, false, false); // 16 x N (1 warp) - auto 
block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true); // 16*Y x N (Y warp) - return block_layout->Repeat({warp_m / 16, 1}, false); + makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false, false); // 16 x N (1 warp) + auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); // 16*Y x N (Y warp) + return block_layout->Repeat({warp_m / 16, 1}, false, false); } Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block_k, diff --git a/src/tl/op/gemm.cc b/src/tl/op/gemm.cc index 6c492faec910..649b7952382b 100644 --- a/src/tl/op/gemm.cc +++ b/src/tl/op/gemm.cc @@ -67,7 +67,16 @@ std::pair Gemm::ComputeWarpPartition(int num_warps, Target target) con int m_warp = 1, n_warp = 1; if (TargetIsHopper(target)) { ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads."; - m_warp = num_warps; + if (this->policy == GemmWarpPolicy::kFullRow || this->policy == GemmWarpPolicy::kSquare) { + m_warp = num_warps; + ICHECK(this->M % num_warps == 0); + } else if (this->policy == GemmWarpPolicy::kFullCol) { + m_warp = 4; + n_warp = num_warps / 4; + ICHECK(this->N % n_warp == 0); + } else { + ICHECK(0) << "Unknown GemmWarpPolicy"; + } return {m_warp, n_warp}; } if (this->policy == GemmWarpPolicy::kFullRow) { diff --git a/src/tl/target/codegen.cc b/src/tl/target/codegen.cc index d2253820e2b1..c973e4e88a12 100644 --- a/src/tl/target/codegen.cc +++ b/src/tl/target/codegen.cc @@ -690,6 +690,10 @@ void CodeGenTL::VisitExpr_(const CallNode* op, std::ostream& os) { int is_inc = Downcast(op->args[1])->value; std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc"; this->stream << func_name << "<" << std::to_string(nreg) << ">();\n"; + } else if (op->op.same_as(tl::WaitWgmma())) { + this->PrintIndent(); + int num_mma = Downcast(op->args[0])->value; + this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; } else if (op->op.same_as(tl::PackB16Op())) { os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ")"; diff --git a/src/tl/tl_templates/gemm_sm90.h b/src/tl/tl_templates/gemm_sm90.h index a5620389b9b0..7bd45157188e 100644 --- a/src/tl/tl_templates/gemm_sm90.h +++ b/src/tl/tl_templates/gemm_sm90.h @@ -64,7 +64,7 @@ class GemmTensorOp { using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, Shape, Int>{}, conditional_t, Step<_2, _1>>{})); - static_assert(num_warp_n == 1); + // static_assert(num_warp_n == 1); static_assert(num_warp_m % 4 == 0); template @@ -73,9 +73,9 @@ class GemmTensorOp { Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast(pA)), SmemLayoutA{}); Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), SmemLayoutB{}); auto tiled_mma = - make_tiled_mma(GMMA::ss_op_selector, Int, Int>, + make_tiled_mma(GMMA::ss_op_selector, Int, Int>, GmmaMajorA, GmmaMajorB>(), - Layout, _1, _1>>{}); + Layout, Int, _1>>{}); auto thr_mma = tiled_mma.get_thread_slice(tid); // Allocate registers for pipelining @@ -119,9 +119,9 @@ class GemmTensorOp { const int tid = threadIdx.x; Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), SmemLayoutB{}); auto tiled_mma = - make_tiled_mma(GMMA::rs_op_selector, Int, Int>, + make_tiled_mma(GMMA::rs_op_selector, Int, Int>, GmmaMajorA, GmmaMajorB>(), - Layout, _1, _1>>{}); + Layout, Int, _1>>{}); auto thr_mma = tiled_mma.get_thread_slice(tid); // Allocate registers for pipelining diff --git a/src/tl/tl_templates/reduce.h b/src/tl/tl_templates/reduce.h index 117638f9c90a..f455685c6c6b 100644 --- 
a/src/tl/tl_templates/reduce.h +++ b/src/tl/tl_templates/reduce.h @@ -36,8 +36,10 @@ struct AllReduce { constexpr int offset = threads / 2; if constexpr (offset >= 32) { __syncthreads(); + // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(256)); red_buf[threadIdx.x] = x; __syncthreads(); + // asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(256)); x = Reducer()(x, red_buf[threadIdx.x ^ offset]); } else { x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset))); diff --git a/tl_scripts/gemm_hopper.py b/tl_scripts/gemm_hopper.py new file mode 100644 index 000000000000..3d317cb41d6c --- /dev/null +++ b/tl_scripts/gemm_hopper.py @@ -0,0 +1,51 @@ +import torch +from tvm import tl +import tvm.tl.language as T + +def matmul(M, N, K, block_M, block_N, block_K): + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype), C: T.Buffer((M, N), dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128 * 2) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_local = T.alloc_fragment((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.annotate_layout({ + A_shared: tl.layout.make_swizzled_layout(A_shared) + }) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(A_shared, A_local) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_local, B_shared, C_local, policy=T.GemmWarpPolicy.FullCol) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + +M, N, K, block_M, block_N, block_K = 8192, 512 * 16, 8192, 64, 512, 64 + +def ref_program(A, B): + return A @ B + +if __name__ == "__main__": + total_flops = 2 * M * N * K + + program = matmul(M, N, K, block_M, block_N, block_K) + mod, params = tl.lower(program) + + mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + + # latency = mod.do_bench(ref_program, warmup=500) + # print("{:.2f} ms".format(latency)) + # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch") + print("{:.2f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file diff --git a/tl_scripts/mha_pipeline.py b/tl_scripts/mha_pipeline.py index be4de6e8bd6b..a66c10cec54d 100644 --- a/tl_scripts/mha_pipeline.py +++ b/tl_scripts/mha_pipeline.py @@ -79,7 +79,6 @@ def MMA1( def Softmax( acc_s: T.Buffer([block_M, block_N], accum_dtype), acc_s_cast: T.Buffer([block_M, block_N], dtype), - acc_o: T.Buffer([block_M, dim], accum_dtype), scores_max: T.Buffer([block_M], accum_dtype), scores_max_prev: T.Buffer([block_M], accum_dtype), scores_scale: T.Buffer([block_M], accum_dtype), @@ -104,9 +103,15 @@ def Softmax( T.reduce_sum(acc_s, scores_sum, dim=1) for i in T.Parallel(block_M): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + ): for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] - T.copy(acc_s, acc_s_cast) @T.prim_func def main( @@ -142,9 +147,10 @@ def main( T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_casual else T.ceildiv(seq_len, block_N) ) - 
for k in T.Pipelined(loop_range, num_stages=1, order=[0,2,1], stage=[0,0,1], group=[[0,1], [2,3,4,5,6,7,8,9,10], [11]]): + for k in T.Pipelined(loop_range, num_stages=2, order=[-1,0,3,1,-1,2], stage=[-1,0,0,1,-1,1], sync=[[0,13],[1,9]], group=[[0], [1,2], [3,4,5,6,7,8,9,10], [11], [12], [13]]): MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] @@ -155,20 +161,22 @@ def main( def ref_program(Q, K, V, casual): - from flash_attn.flash_attn_interface import flash_attn_func - - return flash_attn_func(Q, K, V, causal=casual) + import sys + sys.path.append("/home/msra/cy/tvm.tl/3rdparty/fa3") + from hopper.flash_attn_interface import flash_attn_func + ret = flash_attn_func(Q, K, V, causal=casual) + return ret[0] if __name__ == "__main__": - BATCH, H, N_CTX, D_HEAD = 64, 12, 1024, 128 - casual = True + BATCH, H, N_CTX, D_HEAD = 1, 32, 4096, 128 + casual = False flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul if casual: total_flops *= 0.5 BLOCK_M = 128 - BLOCK_N = 128 # if D_HEAD <= 128 else 32 + BLOCK_N = 176 # if D_HEAD <= 128 else 32 program = flashattn(BATCH, H, N_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) ref_program = partial(ref_program, casual=casual) mod, params = tl.lower(program) @@ -178,6 +186,6 @@ def ref_program(Q, K, V, casual): latency = mod.do_bench(ref_program, warmup=500) print("{:.2f} ms".format(latency)) print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) - latency = mod.do_bench(mod, n_warmup=10, n_repeat=10) + latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="torch") print("{:.2f} ms".format(latency)) print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file diff --git a/tl_scripts/mha_pipeline_search.py b/tl_scripts/mha_pipeline_search.py new file mode 100644 index 000000000000..18dcc0c4bd85 --- /dev/null +++ b/tl_scripts/mha_pipeline_search.py @@ -0,0 +1,202 @@ +import torch +from tvm import tl +import tvm.tl.language as T +from functools import partial +from tvm.tl.autotuner import * +import itertools + +# Codegen bug: +# LoadK should wait for MMA0 done +# @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) +# def tvm_callback_cuda_postproc(code, _): +# code = code.replace("""tl::mbarrier_wait(_mbarrier[1], ((k & 1) ^ 1));""", +# """tl::mbarrier_wait(_mbarrier[1], ((k & 1))); // replace""") +# code = code.replace("""tl::gemm_ss<64, 64, 64, 4, 1, 0, 1>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[4096])), (&(acc_s[0]))); +# #pragma unroll""", +# """tl::gemm_ss<64, 64, 64, 4, 1, 0, 1>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[4096])), (&(acc_s[0]))); +# tl::mbarrier_arrive(_mbarrier[1]); +# #pragma unroll // replace""") +# return code + +# loadk(0) +# gemm0(0) +# loadk(1) +# softmax(0) +# loadv(0) + +# for i in range(loop_range - 2): +# gemm0(i+1) +# gemm1(i+0) +# loadk(i+2) +# softmax(i+1) +# loadv(i+1) + +# gemm0(loop_range - 1) +# gemm1(loop_range - 2) +# softmax(loop_range - 1) +# loadv(loop_range - 1) +# gemm1(loop_range - 1) + +def ref_program(Q, K, V, casual): + # from flash_attn.flash_attn_interface import flash_attn_func + import sys + sys.path.append("/home/msra/cy/tvm.tl/3rdparty/fa3") 
+ from hopper.flash_attn_interface import flash_attn_func + ret = flash_attn_func(Q, K, V, causal=casual) + return ret[0] + +def get_configs(): + block_M = [64, 128] + block_N = [64 + i * 16 for i in range(13)] + # block_N = [80] + num_stages = [2] + + _configs = list(itertools.product(block_M, block_N, num_stages)) + + configs = [ + {'block_M': c[0], 'block_N': c[1], 'num_stages': c[2], 'thread_num': c[0] * 2} + for c in _configs + ] + return configs + +def flashattn(batch, heads, seq_len, dim, is_casual): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @autotune(configs=get_configs(), keys=['block_M', 'block_N', 'num_stages', 'thread_num'], warmup=10, rep=5) + @jit(out_idx=[3], supply_type=tl.TensorSupplyType.Normal, ref_prog=partial(ref_program, casual=is_casual), rtol=0.01, atol=0.01, profiler="tvm") + def kernel(block_M = None, block_N = None, num_stages = None, thread_num = None): + @T.macro + def MMA0( + K: T.Buffer(shape, dtype), + Q_shared: T.Buffer([block_M, dim], dtype), + K_shared: T.Buffer([block_N, dim], dtype), + acc_s: T.Buffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_casual: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) + ) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Buffer(shape, dtype), + V_shared: T.Buffer([block_M, dim], dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Buffer([block_M, block_N], accum_dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + scores_max: T.Buffer([block_M], accum_dtype), + scores_max_prev: T.Buffer([block_M], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + scores_sum: T.Buffer([block_M], accum_dtype), + logsum: T.Buffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + Output: T.Buffer(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # loop_range = ( + # T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) + # ) + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_casual else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages, order=[-1,0,3,1,-1,2], stage=[-1,0,0,1,-1,1], sync=[[0,13],[1,9]], group=[[0], [1,2], [3,4,5,6,7,8,9,10], [11], [12], [13]]): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + return kernel() + + +if __name__ == "__main__": + BATCH, H, N_CTX, D_HEAD = 1, 6, 1024, 64 + casual = False + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if casual: + total_flops *= 0.5 + best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual) + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") \ No newline at end of file diff --git a/tl_scripts/retnet_example.py b/tl_scripts/retnet_example.py index 7150f3071b5f..2782078069db 100644 --- a/tl_scripts/retnet_example.py +++ b/tl_scripts/retnet_example.py @@ -5,35 +5,44 @@ from functools import partial -def retnet(batch, heads, seq_len, dim, block_M, block_N): - shape = [batch, seq_len, heads, dim] +def retnet(batch, heads, seq_len, dim_qk, dim_v, block_M, block_N): + qk_shape = [batch, seq_len, heads, dim_qk] + v_shape = [batch, seq_len, heads, dim_v] dtype = "float16" accum_dtype = "float" @T.prim_func def main( - Q: T.Buffer(shape, 
dtype), - K: T.Buffer(shape, dtype), - V: T.Buffer(shape, dtype), + Q: T.Buffer(qk_shape, dtype), + K: T.Buffer(qk_shape, dtype), + V: T.Buffer(v_shape, dtype), mask: T.Buffer([heads, seq_len, seq_len], dtype), - Output: T.Buffer(shape, dtype), + Output: T.Buffer(v_shape, dtype), ): - with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128 * 1) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128 * 2) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) mask_shared = T.alloc_shared([block_M, block_N], dtype) - acc_o_shared = T.alloc_shared([block_M, dim], dtype) + acc_o_shared = T.alloc_shared([block_M, dim_v], dtype) mask_local = T.alloc_fragment([block_M, block_N], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_1 = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_shared = T.alloc_shared([block_M, block_N], dtype) acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) abs_sum = T.alloc_fragment([block_M], accum_dtype) r_wo_clamp = T.alloc_fragment([block_M], accum_dtype) r = T.alloc_fragment([block_M], accum_dtype) r_new = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.annotate_layout({ + Q_shared: tl.layout.make_swizzled_layout(Q_shared), + mask_shared: tl.layout.make_swizzled_layout(mask_shared), + acc_s_shared: tl.layout.make_swizzled_layout(acc_s_shared), + acc_o_shared: tl.layout.make_swizzled_layout(acc_o_shared) + }) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) T.fill(r, 0) @@ -43,30 +52,32 @@ def main( loop_range = T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1, - order=[-1,0,-1,-1,1,2], - stage=[-1,0,-1,-1,0,0], - group=[[0],[1,2],[3],[4],[5,6,7,8,9,10,11,12,13], [14]] + order=[-1,0,-1,1,-1,2], + stage=[-1,0,-1,0,-1,0], + group=[[0],[1,2],[3],[4,5,6,7,8,9,10,11,12,13,14],[15],[16]] ): T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) T.copy(mask[by, bx * block_M : (bx + 1) * block_M, k * block_N : (k + 1) * block_N], mask_shared) T.copy(mask_shared, mask_local) for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = acc_s[i, j] * mask_local[i, j] - T.reduce_abssum(acc_s, abs_sum, dim=1) + T.copy(acc_s, acc_s_shared) + T.copy(acc_s_shared, acc_s_1) + T.reduce_abssum(acc_s_1, abs_sum, dim=1) for i in T.Parallel(block_M): r_wo_clamp[i] = r_wo_clamp[i] + abs_sum[i] for i in T.Parallel(block_M): r_new[i] = T.max(r_wo_clamp[i], 1) - for i, j in T.Parallel(block_M, dim): + for i, j in T.Parallel(block_M, dim_v): acc_o[i, j] = T.if_then_else(k > 0, acc_o[i, j] * r[i] / r_new[i], acc_o[i, j]) T.copy(r_new, r) for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = acc_s[i, j] / r_new[i] - T.copy(acc_s, acc_s_cast) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + acc_s_1[i, j] = acc_s_1[i, j] / r_new[i] + 
T.copy(acc_s_1, acc_s_cast) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) T.copy(acc_o, acc_o_shared) T.copy(acc_o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) @@ -74,10 +85,6 @@ def main( def ref_program(Q, K, V, mask): - Q = Q.to(dtype=float) - K = K.to(dtype=float) - V = V.to(dtype=float) - mask = mask.to(dtype=float) qk = torch.einsum('bqhd,bkhd->bhqk', Q, K) qkm = qk * mask r = qkm.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1.0) @@ -90,14 +97,14 @@ def ref_program(Q, K, V, mask): parser.add_argument('--batch', type=int, default=1, help='Batch size') parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') - parser.add_argument('--d_head', type=int, default=128, help='Head dimension') + parser.add_argument('--dim_qk', type=int, default=256, help='Head dimension') + parser.add_argument('--dim_v', type=int, default=448, help='Head dimension') args = parser.parse_args() - BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD - total_flops = 2 * flops_per_matmul + BATCH, H, N_CTX, dim_qk, dim_v = args.batch, args.h, args.n_ctx, args.dim_qk, args.dim_v + total_flops = 2.0 * BATCH * H * N_CTX * N_CTX * (dim_qk + dim_v) BLOCK_M = 64 BLOCK_N = 64 - program = retnet(BATCH, H, N_CTX, D_HEAD, BLOCK_M, BLOCK_N) + program = retnet(BATCH, H, N_CTX, dim_qk, dim_v, BLOCK_M, BLOCK_N) mod, params = tl.lower(program) mod = tl.Profiler(mod, params, [4], tl.TensorSupplyType.Normal) mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) @@ -105,6 +112,6 @@ def ref_program(Q, K, V, mask): latency = mod.do_bench(ref_program, n_warmup=10, n_repeat=1) print("torch: {:.2f} ms".format(latency)) print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - latency = mod.do_bench(mod, n_warmup=10, n_repeat=10) + latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="torch") print("tl: {:.2f} ms".format(latency)) print("tl: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/tl_scripts/torch_ref.py b/tl_scripts/torch_ref.py index 39890b36987e..a9c6451691dc 100644 --- a/tl_scripts/torch_ref.py +++ b/tl_scripts/torch_ref.py @@ -205,7 +205,8 @@ def set_seed(seed): batch = 8 head = 16 seq_len = 1024 -dim = 64 +dim_qk = 64 +dim_v = 448 blockM = seq_len blockN = 64 # q = torch.randn((BATCH, HEAD, SEQLEN, D), dtype=torch.float16) @@ -213,10 +214,11 @@ def set_seed(seed): # v = torch.randn((BATCH, HEAD, SEQLEN, D), dtype=torch.float16) # mask = torch.randn((HEAD, SEQLEN, SEQLEN), dtype=torch.float16) -shape = [batch, seq_len, head, dim] -q = torch.randn(shape, dtype=torch.float16) -k = torch.randn(shape, dtype=torch.float16) -v = torch.randn(shape, dtype=torch.float16) +qk_shape = [batch, seq_len, head, dim_qk] +v_shape = [batch, seq_len, head, dim_v] +q = torch.randn(qk_shape, dtype=torch.float16) +k = torch.randn(qk_shape, dtype=torch.float16) +v = torch.randn(v_shape, dtype=torch.float16) mask = torch.randn((head, seq_len, seq_len), dtype=torch.float16) def ref_program(q, k, v, mask): @@ -230,7 +232,7 @@ def test_program(q, k, v, mask): r = torch.zeros((batch, head, seq_len), dtype=torch.float16) r_new = torch.zeros((batch, head, seq_len), dtype=torch.float16) r_wo_clamp = torch.zeros((batch, head, seq_len), dtype=torch.float16) - acco = torch.zeros((batch, head, seq_len, dim), dtype=torch.float16) + acco = 
torch.zeros((batch, head, seq_len, dim_v), dtype=torch.float16) for i in range(seq_len // blockN): qk = torch.einsum('bqhd,bkhd->bhqk', q, k[:, i * blockN : (i + 1) * blockN, :, :]) qkm = qk * mask[:, :, i * blockN : (i + 1) * blockN].unsqueeze(0) From 7936d2b2280aadf1162aacc5e124ae128c70ddb6 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Thu, 26 Sep 2024 15:11:03 +0000 Subject: [PATCH 19/23] [tl] update mamba --- .gitmodules | 3 + 3rdparty/mamba | 1 + mamba | 1 + tl_scripts/mamba_example.py | 125 ++++++++++++------------------------ 4 files changed, 47 insertions(+), 83 deletions(-) create mode 160000 3rdparty/mamba create mode 160000 mamba diff --git a/.gitmodules b/.gitmodules index e9fefa469390..5e9a9e1fb0eb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -37,3 +37,6 @@ [submodule "3rdparty/flash-linear-attention"] path = 3rdparty/flash-linear-attention url = git@github.com:sustcsonglin/flash-linear-attention.git +[submodule "3rdparty/mamba"] + path = 3rdparty/mamba + url = git@github.com:state-spaces/mamba.git diff --git a/3rdparty/mamba b/3rdparty/mamba new file mode 160000 index 000000000000..62db608da60f --- /dev/null +++ b/3rdparty/mamba @@ -0,0 +1 @@ +Subproject commit 62db608da60f6fc790b8ed9f4b3225e95ca15fde diff --git a/mamba b/mamba new file mode 160000 index 000000000000..62db608da60f --- /dev/null +++ b/mamba @@ -0,0 +1 @@ +Subproject commit 62db608da60f6fc790b8ed9f4b3225e95ca15fde diff --git a/tl_scripts/mamba_example.py b/tl_scripts/mamba_example.py index 198eda87a185..9eca77c971e7 100644 --- a/tl_scripts/mamba_example.py +++ b/tl_scripts/mamba_example.py @@ -1,103 +1,62 @@ import argparse import torch +import torch.nn.functional as F from tvm import tl import tvm.tl.language as T from functools import partial - -def retnet(batch, heads, seq_len, dim, is_casual, block_M, block_N): - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - shape = [batch, seq_len, heads, dim] +chunk_size = 256 +def bmm_chunk(batch, seqlen, ngroups, k, block_M, block_N, block_K): dtype = "float16" accum_dtype = "float" - + nchunks = T.ceildiv(seqlen, chunk_size) @T.prim_func def main( - Q: T.Buffer(shape, dtype), - K: T.Buffer(shape, dtype), - V: T.Buffer(shape, dtype), - Output: T.Buffer(shape, dtype), + A: T.Buffer((batch, seqlen, ngroups, k), dtype), + B: T.Buffer((batch, seqlen, ngroups, k), dtype), + Output: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype) ): - with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - Q_local = T.alloc_fragment([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - lse = T.alloc_fragment([block_M], accum_dtype) - - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) - T.fill(acc_o, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.fill(scores_max_prev, -T.infinity(accum_dtype)) - T.fill(lse, -T.infinity(accum_dtype)) - T.copy(Q_shared, Q_local) - loop_range = T.ceildiv(seq_len, block_N) + with T.Kernel(T.ceildiv(chunk_size, block_M) * T.ceildiv(chunk_size, 
block_N), batch, nchunks * ngroups, threads=128) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + chunk_idx = bz // ngroups + group_idx = bz % ngroups + m_idx = bx // T.ceildiv(chunk_size, block_N) + n_idx = bx % T.ceildiv(chunk_size, block_N) + + loop_range = T.ceildiv(chunk_size, block_K) + T.clear(acc_o) for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) - T.clear(acc_s) - T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_M): - scores_max[i] = T.max(scores_max_prev[i], scores_max[i]) - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i]) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - lse[i] = scores_max[i] + T.log(T.exp(lse[i] - scores_max[i]) + scores_sum[i]) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] = acc_o[i, j] * T.exp(scores_max_prev[i] - scores_max[i]) - T.copy(scores_max, scores_max_prev) - T.copy(acc_s, acc_s_cast) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= T.exp(scores_max[i] - lse[i]) - T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + T.copy(A[by, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + group_idx, + k * block_K : (k + 1) * block_K], + A_shared) + T.copy(B[by, + chunk_idx * chunk_size + n_idx * block_N : chunk_idx * chunk_size + (n_idx + 1) * block_N, + group_idx, + k * block_K : (k + 1) * block_K], + B_shared) + T.gemm(A_shared, B_shared, acc_o, transpose_B=True) + T.copy(acc_o, Output[by, chunk_idx, group_idx, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N]) return main +def ref_program(A, B): + from einops import rearrange, repeat + seqlen = A.shape[1] + nchunks = (seqlen + chunk_size - 1) // chunk_size -def ref_program(Q, K, V, casual): - qk = torch.matmul(Q.permute(0, 2, 1, 3), K.permute(0, 2, 3, 1)) # [B, H, SEQLEN, SEQLEN] - m = qk.max(dim=-1, keepdim=True).values - p = torch.exp(qk - m) - s = p / p.sum(dim=-1, keepdim=True) - o = torch.matmul(s.to(torch.float16), V.permute(0, 2, 1, 3)) # [B, H, SEQLEN, dim] - return o.permute(0, 2, 1, 3) + A = rearrange(A, "b (c l) g d -> b c l g d", c=nchunks) + B = rearrange(B, "b (c l) g d -> b c l g d", c=nchunks) + return torch.einsum("bclgd,bcsgd->bcgls", A, B) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='Batch size') - parser.add_argument('--h', type=int, default=12, help='Number of heads') - parser.add_argument('--n_ctx', type=int, default=2048, help='Context size') - parser.add_argument('--d_head', type=int, default=256, help='Head dimension') - parser.add_argument('--casual', type=bool, default=True, help='Casual flag') - args = parser.parse_args() - BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head - casual = args.casual - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD - total_flops = 2 * flops_per_matmul - if casual: - total_flops *= 0.5 - BLOCK_M = 64 - BLOCK_N = 64 if D_HEAD <= 128 else 32 - program = retnet(BATCH, H, N_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) - 
ref_program = partial(ref_program, casual=casual) + BATCH, SEQLEN, NGROUPS, DSTATE = 8, 4096, 16, 64 + block_M, block_N, block_K = 64, 64, 64 + program = bmm_chunk(BATCH, SEQLEN, NGROUPS, DSTATE, block_M, block_N, block_K) mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) - mod.assert_allclose(ref_program, rtol=0.1, atol=0.1) - - # latency = mod.do_bench(ref_program, n_warmup=10, n_repeat=1) - # print("torch: {:.2f} ms".format(latency)) - # print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - latency = mod.do_bench(mod, n_warmup=10, n_repeat=5) - print("tl: {:.2f} ms".format(latency)) - print("tl: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.1, atol=0.1) \ No newline at end of file From ff772114426029b04f57817d37a3b2e7fcb37100 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Fri, 27 Sep 2024 05:22:36 +0000 Subject: [PATCH 20/23] [tl] update gen_configs --- tl_pipeline/gen_configs.py | 202 +++++++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 tl_pipeline/gen_configs.py diff --git a/tl_pipeline/gen_configs.py b/tl_pipeline/gen_configs.py new file mode 100644 index 000000000000..7dfc12d957d3 --- /dev/null +++ b/tl_pipeline/gen_configs.py @@ -0,0 +1,202 @@ +from typing import List +import networkx as nx +from itertools import permutations + +class Op: + def __init__(self, name: str, reads: List[List[str]], writes: List[List[str]], is_async: List[int]): + self.name = name + self.reads = reads + self.writes = writes + self.is_async = is_async + + def __str__(self): + return f'{self.name}' + + def __repr__(self): + return str(self) + +MMA0 = Op('MMA0', [[None], ['Q_shared', 'K_shared', 'acc_s']], [['acc_s'], ['acc_s']], [1]) +Softmax = Op('Softmax', [['scores_max'], [None], ['acc_s'], ['scores_max', 'scores_max_prev'], ['acc_s', 'scores_max'],['acc_s'],['logsum','scores_scale','scores_sum'], ['acc_s']], [['scores_max_prev'], ['scores_max'], ['scores_max'], ['scores_scale'], ['acc_s'], ['scores_sum'], ['logsum'], ['acc_s_cast']], []) +Rescale = Op('Rescale', [['acc_o', 'scores_scale']], [['acc_o']], []) +MMA1 = Op('MMA1', [['V_shared', 'acc_s_cast', 'acc_o']], [['acc_o']], [0]) + +graph = nx.DiGraph() +graph.add_edge(0, 1) +graph.add_edge(1, 2) +graph.add_edge(2, 3) +nodes = [MMA0, Softmax, Rescale, MMA1] + +for edge in graph.edges: + print(edge) + +max_stream = 2 +def get_issue_info(graph: nx.DiGraph): + n_nodes = len(graph.nodes) + print("n_nodes:", n_nodes) + def get_order(): + # order of op0 must be 0 + nodes = [i for i in range(1, n_nodes)] + partial_all_orders = [list(order) for order in permutations(nodes)] + all_orders = [[0] + order for order in partial_all_orders] + return all_orders + + def validate(order, stage) -> bool: + for i in range(n_nodes): + # print("i:", i) + # print("nx.descendants(graph, i):", [(s, type(s)) for s in nx.descendants(graph, i)]) + for j in range(n_nodes): + if (i == j): + continue + if j not in nx.descendants(graph, i): + continue + if stage[i] < stage[j]: + continue + if stage[i] == stage[j]: + if order[i] > order[j]: + return False + if stage[i] > stage[j]: + return False + return True + + def get_stage(order): + # stage of op0 must be 0 + valid_stages = [] + def gen(n, max_value): + if n == 0: + return [[]] + res = gen(n - 1, max_value) + return [item + [i] for item in res for i in range(max_value + 1)] + + partial_all_stages = gen(n_nodes - 1, 
max_stream) + all_stages = [[0] + item for item in partial_all_stages] + # print("all_stages:", len(all_stages)) + for stage in all_stages: + if validate(order, stage): + valid_stages.append(stage) + # print("valid_stages:", len(valid_stages)) + return valid_stages + + ans = [] + orders = get_order() + + # print("orders:", orders) + for order in orders: + stages = get_stage(order) + for stage in stages: + ans.append((order, stage)) + return ans + +def get_sync_info(graph: nx.DiGraph, issue_info): + reads = [] + writes = [] + asyncs = [] + n_nodes = len(graph.nodes) + def extract_stmts(graph): + cur_id = 0 + for i in range(n_nodes): + assert len(nodes[i].reads) == len(nodes[i].writes) + reads.extend(nodes[i].reads) + writes.extend(nodes[i].writes) + if len(nodes[i].is_async) > 0: + for async_stmt in nodes[i].is_async: + asyncs.append((i, async_stmt)) + cur_id += len(nodes[i].reads) + return reads, writes, asyncs + + def has_intersects(l0, l1): + for item in l0: + if item in l1: + return True + return False + + def get_valid_pos(sync_node, mma_node, async_stmt): + orders = issue_info[0] + stages = issue_info[1] + #Step 1. Check if sync before this node is possible + mma_reads = nodes[mma_node].reads[async_stmt] + mma_writes = nodes[mma_node].writes[async_stmt] + if orders[sync_node] > orders[mma_node]: + for mid_node in range(n_nodes): + if orders[mid_node] >= orders[sync_node] \ + or orders[mid_node] <= orders[mma_node]: + continue + op_reads = [r for rs in nodes[mid_node].reads for r in rs] + op_writes = [w for ws in nodes[mid_node].writes for w in ws] + if stages[mid_node] == stages[mma_node] \ + and has_intersects(mma_writes, op_reads): + return None + if orders[sync_node] <= orders[mma_node]: + # Check from mma_node to the end + for mid_node in range(n_nodes): + if orders[mid_node] <= orders[mma_node]: + continue + op_reads = [r for rs in nodes[mid_node].reads for r in rs] + op_writes = [w for ws in nodes[mid_node].writes for w in ws] + if stages[mid_node] == stages[mma_node] \ + and has_intersects(mma_writes, op_reads): + return None + # Check from the start to sync node + for mid_node in range(n_nodes): + if orders[mid_node] >= orders[sync_node]: + continue + op_reads = [r for rs in nodes[mid_node].reads for r in rs] + op_writes = [w for ws in nodes[mid_node].writes for w in ws] + if stages[mid_node] == stages[mma_node] + 1 \ + and has_intersects(mma_writes, op_reads): + return None + + # Step 2. 
Find the lateset possible sync position in the node + stmt_num = len(nodes[sync_node].reads) + valid_pos = -1 + for i in range(stmt_num): + stmt_reads = nodes[sync_node].reads[i] + stmt_writes = nodes[sync_node].writes[i] + if orders[sync_node] < orders[mma_node] \ + and stages[sync_node] == stages[mma_node] + 1 \ + and has_intersects(mma_writes, stmt_reads): + break + if orders[sync_node] > orders[mma_node] \ + and stages[sync_node] == stages[mma_node] \ + and has_intersects(mma_writes, stmt_reads): + break + valid_pos += 1 + + if valid_pos == -1: + return None + return valid_pos + + async_stmt_id = 0 + # we try to put the sync as close to the end as possible + extract_stmts(graph) + print("reads:", reads) + print("writes:", writes) + print("asyncs:", asyncs) + all_sync_pos = [] + for node, async_stmt in asyncs: + sync_pos_list = [] + pre_stmt_num = 0 + for sync_node in range(n_nodes): + sync_pos = get_valid_pos(sync_node, node, async_stmt) + # print("sync_pos:", sync_pos + pre_stmt_num if sync_pos is not None else None) + if sync_pos is not None: + sync_pos_list.append(sync_pos + pre_stmt_num) + pre_stmt_num += len(nodes[sync_node].reads) + # print("sync_pos_list:", sync_pos_list) + all_sync_pos.append(sync_pos_list) + print("all_sync_pos:", all_sync_pos) + return all_sync_pos + +def gen_configs(graph: nx.DiGraph): + config = [] + issue_infos = get_issue_info(graph) + print("issue_infos:", issue_infos) + print("issue_infos:", len(issue_infos)) + for issue_info in issue_infos: + issue_info = ([0,1,2,3], [0,0,0,0]) + syncs = get_sync_info(graph, issue_info) + break + for sync in syncs: + config.append((issue_info, sync)) + return config + +pipeline_configs = gen_configs(graph) \ No newline at end of file From 1f5a5f4390ee1d87fd771cfb7a34626b76472a62 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Wed, 2 Oct 2024 06:49:03 +0000 Subject: [PATCH 21/23] [tl] update --- .gitignore | 7 +- mamba | 1 - python/tvm/tl/autotuner.py | 6 +- python/tvm/tl/modified_code.cu | 191 --- testing/mamba_triton.py | 81 ++ tl_scripts/dequant_gemm.py | 144 +++ tl_scripts/mamba_example.py | 1116 ++++++++++++++++- tl_scripts/mha_test.py | 191 --- .../{profile.py => profile_workloads.py} | 0 tl_scripts/retnet_example.py | 8 + tl_verify/compile.py | 46 - tl_verify/cuda_interface.cpp | 57 - tl_verify/fa_kernel.cu | 429 ------- tl_verify/fa_kernel.hpp | 26 - tl_verify/fa_no_tma.cu | 199 --- tl_verify/main.py | 91 -- tl_verify/setup.py | 110 -- 17 files changed, 1323 insertions(+), 1380 deletions(-) delete mode 160000 mamba delete mode 100644 python/tvm/tl/modified_code.cu create mode 100644 testing/mamba_triton.py create mode 100644 tl_scripts/dequant_gemm.py delete mode 100644 tl_scripts/mha_test.py rename tl_scripts/{profile.py => profile_workloads.py} (100%) delete mode 100644 tl_verify/compile.py delete mode 100644 tl_verify/cuda_interface.cpp delete mode 100644 tl_verify/fa_kernel.cu delete mode 100644 tl_verify/fa_kernel.hpp delete mode 100644 tl_verify/fa_no_tma.cu delete mode 100644 tl_verify/main.py delete mode 100644 tl_verify/setup.py diff --git a/.gitignore b/.gitignore index 55718420119d..21500360b2f1 100644 --- a/.gitignore +++ b/.gitignore @@ -281,4 +281,9 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py cmake/config.cmake */reports/* play.py -*.ptx \ No newline at end of file +*.ptx +*.ncu-rep +tl_verify/* +*/modified_code.cu +modified_code.cu +code_replace.py \ No newline at end of file diff --git a/mamba b/mamba deleted file mode 160000 index 62db608da60f..000000000000 --- a/mamba +++ /dev/null @@ -1 
+0,0 @@ -Subproject commit 62db608da60f6fc790b8ed9f4b3225e95ca15fde diff --git a/python/tvm/tl/autotuner.py b/python/tvm/tl/autotuner.py index e9dd4c90755c..b228a200ad6a 100644 --- a/python/tvm/tl/autotuner.py +++ b/python/tvm/tl/autotuner.py @@ -107,6 +107,7 @@ def jit( out_idx: List[int], supply_type: tl.TensorSupplyType = tl.TensorSupplyType.Normal, ref_prog: Callable = None, + check_close: bool = True, rtol: float = 1e-5, atol: float = 1e-5, profiler: str = "torch" @@ -119,10 +120,11 @@ def decorator(*args, **kwargs) -> float: nonlocal ref_latency_cache mod, params = tl.lower(fn(*args, **kwargs)) mod = tl.Profiler(mod, params, out_idx, supply_type) - mod.assert_allclose(ref_prog, rtol=rtol, atol=atol) + if check_close: + mod.assert_allclose(ref_prog, rtol=rtol, atol=atol) latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler=profiler) if ref_latency_cache is None and ref_prog is not None: - ref_latency_cache = mod.do_bench(ref_prog, n_warmup=10, n_repeat=10, profiler=profiler) + ref_latency_cache = mod.do_bench(ref_prog, n_warmup=10, n_repeat=10, profiler="torch") return latency, ref_latency_cache return decorator return wrapper \ No newline at end of file diff --git a/python/tvm/tl/modified_code.cu b/python/tvm/tl/modified_code.cu deleted file mode 100644 index f8d90d99de36..000000000000 --- a/python/tvm/tl/modified_code.cu +++ /dev/null @@ -1,191 +0,0 @@ -#include -#include -#include -#include -#include - -extern "C" __global__ void __launch_bounds__(384) main_kernel(__grid_constant__ const CUtensorMap K_desc, __grid_constant__ const CUtensorMap Output_desc, __grid_constant__ const CUtensorMap Q_desc, __grid_constant__ const CUtensorMap V_desc, __grid_constant__ const CUtensorMap mask_desc) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - float r[2]; - float r_new[2]; - float r_wo_clamp[2]; - float acc_o[112]; - float acc_s[16]; - half_t mask_local[16]; - float acc_s_1[32]; - float abs_sum[2]; - half_t acc_s_cast[32]; - __shared__ uint64_t _mbarrier[11]; - if (((int)threadIdx.x) == 0) { - tl::prefetch_tma_descriptor(Q_desc); - tl::prefetch_tma_descriptor(K_desc); - tl::prefetch_tma_descriptor(mask_desc); - tl::prefetch_tma_descriptor(V_desc); - tl::prefetch_tma_descriptor(Output_desc); - tl::mbarrier_init(_mbarrier[0], 128); - tl::mbarrier_init(_mbarrier[1], 128); - tl::mbarrier_init(_mbarrier[2], 128); - tl::mbarrier_init(_mbarrier[3], 256); - tl::mbarrier_init(_mbarrier[4], 256); - tl::mbarrier_init(_mbarrier[5], 256); - tl::mbarrier_init(_mbarrier[6], 128); - tl::mbarrier_init(_mbarrier[7], 256); - tl::mbarrier_init(_mbarrier[8], 256); - tl::mbarrier_init(_mbarrier[9], 256); - tl::mbarrier_init(_mbarrier[10], 256); - } - __syncthreads(); - if (256 <= ((int)threadIdx.x)) { - tl::warpgroup_reg_dealloc<24>(); - // if (((int)threadIdx.x) == 256) { - // tl::mbarrier_expect_tx(_mbarrier[6], 32768); - // } - // if (((int)threadIdx.x) == 256) { - // tl::tma_load(Q_desc, _mbarrier[6], (&(((half_t*)buf_dyn_shmem)[12288])), 0, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - // tl::tma_load(Q_desc, _mbarrier[6], (&(((half_t*)buf_dyn_shmem)[16384])), 64, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - // tl::tma_load(Q_desc, _mbarrier[6], (&(((half_t*)buf_dyn_shmem)[20480])), 128, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - // tl::tma_load(Q_desc, _mbarrier[6], (&(((half_t*)buf_dyn_shmem)[24576])), 192, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - // } - // tl::mbarrier_arrive(_mbarrier[6]); - // for (int k = 0; k < 128; ++k) { - // 
tl::mbarrier_wait(_mbarrier[3], ((k & 1) ^ 1)); - // if (((int)threadIdx.x) == 256) { - // tl::mbarrier_expect_tx(_mbarrier[0], 32768); - // } - // if (((int)threadIdx.x) == 256) { - // tl::tma_load(K_desc, _mbarrier[0], (&(((half_t*)buf_dyn_shmem)[28672])), 0, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(K_desc, _mbarrier[0], (&(((half_t*)buf_dyn_shmem)[32768])), 64, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(K_desc, _mbarrier[0], (&(((half_t*)buf_dyn_shmem)[36864])), 128, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(K_desc, _mbarrier[0], (&(((half_t*)buf_dyn_shmem)[40960])), 192, ((int)blockIdx.y), (k * 64), 0); - // } - // tl::mbarrier_arrive(_mbarrier[0]); - // tl::mbarrier_wait(_mbarrier[4], ((k & 1) ^ 1)); - // if (((int)threadIdx.x) == 256) { - // tl::mbarrier_expect_tx(_mbarrier[1], 8192); - // } - // if (((int)threadIdx.x) == 256) { - // tl::tma_load(mask_desc, _mbarrier[1], (&(((half_t*)buf_dyn_shmem)[0])), (k * 64), (((int)blockIdx.x) * 64), ((int)blockIdx.y)); - // } - // tl::mbarrier_arrive(_mbarrier[1]); - // tl::mbarrier_wait(_mbarrier[5], ((k & 1) ^ 1)); - // if (((int)threadIdx.x) == 256) { - // tl::mbarrier_expect_tx(_mbarrier[2], 57344); - // } - // if (((int)threadIdx.x) == 256) { - // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[45056])), 0, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[49152])), 64, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[53248])), 128, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[57344])), 192, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[61440])), 256, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[65536])), 320, ((int)blockIdx.y), (k * 64), 0); - // tl::tma_load(V_desc, _mbarrier[2], (&(((half_t*)buf_dyn_shmem)[69632])), 384, ((int)blockIdx.y), (k * 64), 0); - // } - // tl::mbarrier_arrive(_mbarrier[2]); - // } - } else { - tl::warpgroup_reg_alloc<240>(); - #pragma unroll - for (int i = 0; i < 2; ++i) { - r[i] = 0.000000e+00f; - } - #pragma unroll - for (int i_1 = 0; i_1 < 2; ++i_1) { - r_new[i_1] = 0.000000e+00f; - } - #pragma unroll - for (int i_2 = 0; i_2 < 2; ++i_2) { - r_wo_clamp[i_2] = 0.000000e+00f; - } - #pragma unroll - for (int i_3 = 0; i_3 < 112; ++i_3) { - acc_o[i_3] = 0.000000e+00f; - } - tl::fence_proxy_async(); - // tl::mbarrier_wait(_mbarrier[6], 0); - #pragma unroll 1 - for (int k_1 = 0; k_1 < 128; ++k_1) { - #pragma unroll - for (int i_4 = 0; i_4 < 16; ++i_4) { - acc_s[i_4] = 0.000000e+00f; - } - tl::fence_proxy_async(); - // tl::mbarrier_wait(_mbarrier[0], (k_1 & 1)); - tl::gemm_ss<64, 64, 256, 4, 2, 0, 1>((&(((half_t*)buf_dyn_shmem)[12288])), (&(((half_t*)buf_dyn_shmem)[28672])), (&(acc_s[0]))); - // tl::mbarrier_arrive(_mbarrier[3]); - // tl::mbarrier_wait(_mbarrier[1], (k_1 & 1)); - #pragma unroll - for (int i_5 = 0; i_5 < 2; ++i_5) { - tl::ptx_ldmatrix_x4((&(((half_t*)buf_dyn_shmem)[(((((((((int)threadIdx.x) & 127) >> 5) * 1024) + ((((int)threadIdx.x) & 15) * 64)) + ((((int)threadIdx.x) >> 7) * 32)) + (i_5 * 16)) + (((((int)threadIdx.x) & 31) >> 4) * 8))])), (&(mask_local[(i_5 * 8)]))); - } - tl::fence_proxy_async(); - // tl::mbarrier_arrive(_mbarrier[4]); - #pragma unroll - for (int i_6 = 0; i_6 < 16; ++i_6) { - acc_s[i_6] = (acc_s[i_6] * ((float)mask_local[i_6])); - } - tl::syncthreads_partial(_mbarrier[7]); - 
#pragma unroll - for (int i_7 = 0; i_7 < 8; ++i_7) { - ((float2*)buf_dyn_shmem)[(((((((((((int)threadIdx.x) & 127) >> 5) * 512) + ((i_7 & 1) * 256)) + (((((int)threadIdx.x) & 31) >> 2) * 32)) + ((((int)threadIdx.x) >> 7) * 16)) + ((i_7 >> 1) * 4)) + (((int)threadIdx.x) & 3)) + 1024)] = *(float2*)(acc_s + (i_7 * 2)); - } - tl::syncthreads_partial(_mbarrier[8]); - #pragma unroll - for (int i_8 = 0; i_8 < 16; ++i_8) { - *(float2*)(acc_s_1 + (i_8 * 2)) = ((float2*)buf_dyn_shmem)[((((((((((int)threadIdx.x) & 127) >> 5) * 512) + ((i_8 & 1) * 256)) + (((((int)threadIdx.x) & 31) >> 2) * 32)) + ((i_8 >> 1) * 4)) + (((int)threadIdx.x) & 3)) + 1024)]; - } - #pragma unroll - for (int i_9 = 0; i_9 < 2; ++i_9) { - abs_sum[i_9] = 0.000000e+00f; - #pragma unroll - for (int rv = 0; rv < 16; ++rv) { - abs_sum[i_9] = (abs_sum[i_9] + max(acc_s_1[((((rv & 7) * 4) + (i_9 * 2)) + (rv >> 3))], (0.000000e+00f - acc_s_1[((((rv & 7) * 4) + (i_9 * 2)) + (rv >> 3))]))); - } - abs_sum[i_9] = tl::AllReduce::run(abs_sum[i_9]); - } - #pragma unroll - for (int i_10 = 0; i_10 < 2; ++i_10) { - r_wo_clamp[i_10] = (r_wo_clamp[i_10] + abs_sum[i_10]); - } - #pragma unroll - for (int i_11 = 0; i_11 < 2; ++i_11) { - r_new[i_11] = max(r_wo_clamp[i_11], 1.000000e+00f); - } - #pragma unroll - for (int i_12 = 0; i_12 < 112; ++i_12) { - acc_o[i_12] = ((0 < k_1) ? ((acc_o[i_12] * r[((i_12 & 3) >> 1)]) / r_new[((i_12 & 3) >> 1)]) : acc_o[i_12]); - } - #pragma unroll - for (int i_13 = 0; i_13 < 2; ++i_13) { - r[i_13] = r_new[i_13]; - } - #pragma unroll - for (int i_14 = 0; i_14 < 32; ++i_14) { - acc_s_1[i_14] = (acc_s_1[i_14] / r_new[((i_14 & 3) >> 1)]); - } - #pragma unroll - for (int i_15 = 0; i_15 < 32; ++i_15) { - acc_s_cast[i_15] = ((half_t)acc_s_1[i_15]); - } - tl::fence_proxy_async(); - // tl::mbarrier_wait(_mbarrier[2], (k_1 & 1)); - tl::gemm_rs<64, 448, 64, 4, 2, 0, 0>((&(acc_s_cast[0])), (&(((half_t*)buf_dyn_shmem)[45056])), (&(acc_o[0]))); - // tl::mbarrier_arrive(_mbarrier[5]); - } - tl::syncthreads_partial(_mbarrier[9]); - #pragma unroll - for (int i_16 = 0; i_16 < 14; ++i_16) { - tl::ptx_stmatrix_x4((&(((half_t*)buf_dyn_shmem)[(((((((((((((int)threadIdx.x) >> 7) * 7) + (i_16 >> 1)) >> 1) * 4096) + (((((int)threadIdx.x) & 127) >> 5) * 1024)) + ((((int)threadIdx.x) & 15) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + (((((int)threadIdx.x) >> 7) + (i_16 >> 1)) & 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (i_16 & 1)) & 1) * 16)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 45056)])), __pack_half2(((half_t)acc_o[(i_16 * 8)]), ((half_t)acc_o[((i_16 * 8) + 1)])), __pack_half2(((half_t)acc_o[((i_16 * 8) + 2)]), ((half_t)acc_o[((i_16 * 8) + 3)])), __pack_half2(((half_t)acc_o[((i_16 * 8) + 4)]), ((half_t)acc_o[((i_16 * 8) + 5)])), __pack_half2(((half_t)acc_o[((i_16 * 8) + 6)]), ((half_t)acc_o[((i_16 * 8) + 7)]))); - } - tl::fence_proxy_async(); - tl::syncthreads_partial(_mbarrier[10]); - if (((int)threadIdx.x) == 0) { - tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[45056])), 0, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[49152])), 64, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[53248])), 128, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[57344])), 192, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[61440])), 256, ((int)blockIdx.y), 
(((int)blockIdx.x) * 64), 0); - tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[65536])), 320, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[69632])), 384, ((int)blockIdx.y), (((int)blockIdx.x) * 64), 0); - } - } -} \ No newline at end of file diff --git a/testing/mamba_triton.py b/testing/mamba_triton.py new file mode 100644 index 000000000000..a5982fd054cd --- /dev/null +++ b/testing/mamba_triton.py @@ -0,0 +1,81 @@ +import torch +import triton +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.ops.triton.selective_state_update import selective_state_update + +configs = [triton.testing.Benchmark( + x_names=["batch", "seq_len", "nheads", "headdim", "ngroups", "dstate", "chunk_size"], + x_vals=[ + (64, 4096, 64, 64, 8, 64, 256), + ], + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. + line_vals=[""], + line_names=[""], + styles=[("green", "-"), ("blue", "-")], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="mamba2-performance-fp16", + args={}, + )] + +# @triton.testing.perf_report(configs) +# def benchmark(batch, seq_len, nheads, headdim, ngroups, dstate, chunk_size, provider): +# warmup = 25 +# rep = 100 +# x = torch.empty((batch, seq_len, nheads, headdim), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) +# dt = torch.empty((batch, seq_len, nheads), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) +# A = torch.empty((nheads), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) +# B = torch.empty((batch, seq_len, ngroups, dstate), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) +# C = torch.empty((batch, seq_len, ngroups, dstate), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) + +# quantiles = [0.5, 0.2, 0.8] +# ms, min_ms, max_ms = triton.testing.do_bench(lambda: mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size), warmup=warmup, rep=rep, quantiles=quantiles) +# flops_chunk_cumsum_fwd = 0 +# flops_chunk_state_fwd = 2.0 * batch * seq_len * nheads * headdim * dstate +# flops_state_passing_fwd = 0 +# flops_bmm_chunk_fwd = 2.0 * batch * ngroups * dstate * seq_len * chunk_size +# flops_chunk_scan_fwd = 2.0 * batch * seq_len * chunk_size * nheads * headdim + 2.0 * batch * seq_len * nheads * headdim * dstate +# total_flops = flops_chunk_cumsum_fwd + flops_chunk_state_fwd + flops_state_passing_fwd + flops_bmm_chunk_fwd + flops_chunk_scan_fwd +# perf = lambda ms: total_flops * 1e-12 / (ms * 1e-3) +# return perf(ms), perf(max_ms), perf(min_ms) + + +# benchmark.run(show_plots=True, print_data=True) + + +@triton.testing.perf_report(configs) +def benchmark(batch, seq_len, nheads, headdim, ngroups, dstate, chunk_size, provider): + warmup = 25 + rep = 100 + x = torch.empty((batch, seq_len, nheads, headdim), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) + dt = torch.empty((batch, seq_len, nheads), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) + A = torch.empty((nheads), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) + B = torch.empty((batch, seq_len, ngroups, dstate), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) + C = torch.empty((batch, seq_len, ngroups, dstate), device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench(lambda: mamba_chunk_scan_combined(x, dt, A, B, C, 
chunk_size), warmup=warmup, rep=rep, quantiles=quantiles) + flops_chunk_cumsum_fwd = 0 + flops_chunk_state_fwd = 2.0 * batch * seq_len * nheads * headdim * dstate + flops_state_passing_fwd = 0 + flops_bmm_chunk_fwd = 2.0 * batch * ngroups * dstate * seq_len * chunk_size + flops_chunk_scan_fwd = 2.0 * batch * seq_len * chunk_size * nheads * headdim + 2.0 * batch * seq_len * nheads * headdim * dstate + total_flops = flops_chunk_cumsum_fwd + flops_chunk_state_fwd + flops_state_passing_fwd + flops_bmm_chunk_fwd + flops_chunk_scan_fwd + perf = lambda ms: total_flops * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# benchmark.run(show_plots=True, print_data=True) +# benchmark(64, 4096, 64, 64, 8, 64, 256, "") + +from transformers import AutoTokenizer, Mamba2Model +import torch + +tokenizer = AutoTokenizer.from_pretrained("mistralai/mamba-codestral-7B-v0.1") +model = Mamba2Model.from_pretrained("mistralai/mamba-codestral-7B-v0.1") + +inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") +outputs = model(**inputs) + +last_hidden_states = outputs.last_hidden_state \ No newline at end of file diff --git a/tl_scripts/dequant_gemm.py b/tl_scripts/dequant_gemm.py new file mode 100644 index 000000000000..d35d6ad4eed1 --- /dev/null +++ b/tl_scripts/dequant_gemm.py @@ -0,0 +1,144 @@ +import tvm +from tvm import tl + +def _tir_packed_to_unsigned_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype) + + return f_convert + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + dtypeAB, + dtypeC, + accum_dtype, + num_stages, + threads, + num_bits=4, +): + num_elems_per_byte = 8 // num_bits + storage_dtype = "int8" + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + import tvm.tl.language as T + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment([8], storage_dtype) + B_dequantize_local = T.alloc_fragment([16], dtypeAB) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 16)): + for t in T.thread_binding(0, threads, thread="threadIdx.x"): + for v in T.vectorized(0, 16): + vi = (i * threads * 16 + t * 16 + v) // (block_K // num_elems_per_byte) + vj = (i * threads * 16 + t * 16 + v) % (block_K // num_elems_per_byte) + B_shared[vi, vj] = B[bx * block_N + vi, + k * block_K // num_elems_per_byte + vj,] + + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)): + for t in T.thread_binding(0, threads, thread="threadIdx.x"): + for v in T.vectorized(0, 4): + vi = (i * threads * 
4 + t * 4 + v) // (block_K // num_elems_per_byte) + vj = (i * threads * 4 + t * 4 + v) % (block_K // num_elems_per_byte) + B_local[v] = B_shared[vi, vj] + for v in T.serial(0, 8): + B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)( + num_bits, + B_local[v // 2], + v % 2, + dtype=dtypeAB, + ) + for v in T.vectorized(0, 8): + vi = (i * threads * 8 + t * 8 + v) // (block_K) + vj = (i * threads * 8 + t * 8 + v) % (block_K) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + dtypeAB, + dtypeC, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + dtypeAB, + dtypeC, + dtypeAccum, + num_stages, + num_threads, + ) + print(program) + + mod, params = tl.lower(program) + mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + + out = mod.run_once() + + print(f"output is {out}") + + def ref_program(A, qB): + import torch + + B = ( + torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, + dtype=torch.half).to(torch.half).to(A.device)) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + mod.assert_allclose(ref_program) + + +def test_run_dequantize_gemm(): + run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 64, num_threads=128) + # run_gemm(256, 256, 256, "float16", "float16", "float32", 128, 128, 64, num_threads=128) + + +if __name__ == "__main__": + # bitblas.testing.main() + test_run_dequantize_gemm() diff --git a/tl_scripts/mamba_example.py b/tl_scripts/mamba_example.py index 9eca77c971e7..6ff4a696494b 100644 --- a/tl_scripts/mamba_example.py +++ b/tl_scripts/mamba_example.py @@ -3,60 +3,1104 @@ import torch.nn.functional as F from tvm import tl import tvm.tl.language as T +from tvm.tl.autotuner import * from functools import partial +from einops import rearrange, repeat +import triton +import itertools chunk_size = 256 -def bmm_chunk(batch, seqlen, ngroups, k, block_M, block_N, block_K): + +#################################################################################################### +# bmm_chunk +#################################################################################################### + +# def bmm_chunk(batch, seqlen, ngroups, dstate, block_M = None, block_N = None, block_K = None, num_stages = None, thread_num = None): +# dtype = "float16" +# accum_dtype = "float" +# nchunks = T.ceildiv(seqlen, chunk_size) +# @T.prim_func +# def main( +# A: T.Buffer((batch, seqlen, ngroups, dstate), dtype), +# B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), +# Output: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype) +# ): +# with T.Kernel(T.ceildiv(chunk_size, block_M) * T.ceildiv(chunk_size, block_N), batch, nchunks * ngroups, threads=thread_num) as (bx, by, bz): +# A_shared = T.alloc_shared((block_M, block_K), dtype) +# B_shared = T.alloc_shared((block_N, block_K), dtype) +# acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) +# acc_o_shared = T.alloc_shared((block_M, block_N), dtype) +# chunk_idx = bz // ngroups +# group_idx = bz % ngroups +# m_idx = bx // T.ceildiv(chunk_size, block_N) +# n_idx = bx % T.ceildiv(chunk_size, block_N) + +# # T.annotate_layout({acc_o_shared: 
tl.layout.make_swizzled_layout(acc_o_shared)}) + +# loop_range = T.ceildiv(dstate, block_K) +# T.clear(acc_o) +# for k in T.Pipelined(loop_range, num_stages=num_stages): +# T.copy(A[by, +# chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, +# group_idx, +# k * block_K : (k + 1) * block_K], +# A_shared) +# T.copy(B[by, +# chunk_idx * chunk_size + n_idx * block_N : chunk_idx * chunk_size + (n_idx + 1) * block_N, +# group_idx, +# k * block_K : (k + 1) * block_K], +# B_shared) +# T.gemm(A_shared, B_shared, acc_o, transpose_B=True) +# T.copy(acc_o, acc_o_shared) +# T.copy(acc_o_shared, Output[by, chunk_idx, group_idx, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N]) + +# return main + +# def bmm_triton(A, B): +# from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd +# return _bmm_chunk_fwd(A, B, chunk_size) + +# def bmm_ref_program(A, B): +# seqlen = A.shape[1] +# nchunks = (seqlen + chunk_size - 1) // chunk_size + +# A = rearrange(A, "b (c l) g d -> b c l g d", c=nchunks) +# B = rearrange(B, "b (c l) g d -> b c l g d", c=nchunks) +# return torch.einsum("bclgd,bcsgd->bcgls", A, B) + +def bmm_chunk(batch, seqlen, ngroups, dstate): + + def bmm_ref_program(A, B): + seqlen = A.shape[1] + nchunks = (seqlen + chunk_size - 1) // chunk_size + + A = rearrange(A, "b (c l) g d -> b c l g d", c=nchunks) + B = rearrange(B, "b (c l) g d -> b c l g d", c=nchunks) + + return torch.einsum("bclgd,bcsgd->bcgls", A, B) + + def bmm_triton(A, B): + from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd + return _bmm_chunk_fwd(A, B, chunk_size) + + def get_configs(): + block_M = [64, 128] + block_N = [32, 64, 128] + block_K = [32, 64] + num_stages = [1, 2] + _configs = list(itertools.product(block_M, block_N, block_K, num_stages)) + + configs = [ + {'block_M': c[0], 'block_N': c[1], 'block_K': c[2], 'num_stages': c[3], 'thread_num': c[0] * 2} + for c in _configs + ] + return configs + + @autotune(configs=get_configs(), keys=['block_M', 'block_N', 'block_K', 'num_stages', 'thread_num'], warmup=10, rep=5) + @jit(out_idx=[2], supply_type=tl.TensorSupplyType.Normal, ref_prog=bmm_triton, rtol=0.01, atol=0.01, profiler="tvm") + def kernel(block_M = None, block_N = None, block_K = None, num_stages = None, thread_num = None): + dtype = "float16" + accum_dtype = "float" + nchunks = T.ceildiv(seqlen, chunk_size) + @T.prim_func + def main( + A: T.Buffer((batch, seqlen, ngroups, dstate), dtype), + B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), + Output: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype) + ): + with T.Kernel(T.ceildiv(chunk_size, block_M) * T.ceildiv(chunk_size, block_N), batch, nchunks * ngroups, threads=thread_num) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + chunk_idx = bz // ngroups + group_idx = bz % ngroups + m_idx = bx // T.ceildiv(chunk_size, block_N) + n_idx = bx % T.ceildiv(chunk_size, block_N) + + loop_range = T.ceildiv(dstate, block_K) + T.clear(acc_o) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(A[by, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + group_idx, + k * block_K : (k + 1) * block_K], + A_shared) + T.copy(B[by, + chunk_idx * chunk_size + n_idx * block_N : chunk_idx * chunk_size + (n_idx + 1) * block_N, + group_idx, + k * 
block_K : (k + 1) * block_K], + B_shared) + T.gemm(A_shared, B_shared, acc_o, transpose_B=True) + T.copy(acc_o, acc_o_shared) + T.copy(acc_o_shared, Output[by, chunk_idx, group_idx, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N]) + + return main + return kernel() + +#################################################################################################### +# chunk_state +#################################################################################################### + +# def chunk_state_triton(B, x, dt, dA_cumsum): +# from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd +# return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False) + +# def chunk_state_fwd(batch, seqlen, ngroups, nheads, headdim, dstate, block_M, block_N, block_K): +# dtype = "float16" +# accum_dtype = "float" +# nchunks = T.ceildiv(seqlen, chunk_size) +# p = 1.44269504 +# @T.prim_func +# def main( +# B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), +# x: T.Buffer((batch, seqlen, nheads, headdim), dtype), +# dt: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), +# dA_cumsum: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), +# Output: T.Buffer((batch, nchunks, nheads, headdim, dstate), dtype) +# ): +# with T.Kernel(T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, nheads, threads=128) as (bx, by, bz): +# x_shared = T.alloc_shared((block_K, block_M), dtype) +# x_local = T.alloc_fragment((block_K, block_M), dtype) +# xt_local = T.alloc_fragment((block_M, block_K), dtype) +# B_shared = T.alloc_shared((block_K, block_N), dtype) +# dt_shared = T.alloc_shared((block_K), dtype) +# dA_cumsum_shared = T.alloc_shared((block_K), dtype) +# acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) +# acc_o_shared = T.alloc_shared((block_M, block_N), dtype) +# scale = T.alloc_fragment((block_K), accum_dtype) +# dA_cs_last = T.alloc_fragment((1), accum_dtype) +# dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype) +# dt_local = T.alloc_fragment((block_K), accum_dtype) + +# loop_range = T.ceildiv(chunk_size, block_K) + +# batch_idx = by % batch +# chunk_idx = by // batch +# m_idx = bx // T.ceildiv(dstate, block_N) +# n_idx = bx % T.ceildiv(dstate, block_N) + +# T.annotate_layout({ +# acc_o_shared: tl.layout.make_swizzled_layout(acc_o_shared) +# }) + +# dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] +# T.clear(acc_o) +# for k in T.Pipelined( +# loop_range, +# num_stages=4, +# order=[-1,-1,-1,1,-1,0], +# stage=[-1,-1,-1,0,-1,1], +# group=[[0],[1],[2],[3,4,5,6,7],[8],[9]], +# ): +# T.copy(x[batch_idx, +# chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, +# bz, +# m_idx * block_M : (m_idx + 1) * block_M], +# x_shared) +# T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared) +# T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) +# T.copy(dA_cumsum_shared, dA_cumsum_local) +# T.copy(dt_shared, dt_local) +# for i in T.Parallel(block_K): +# scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i] +# T.copy(x_shared, x_local) +# for i, j in T.Parallel(block_M, block_K): +# xt_local[i, j] = x_local[j, i] * scale[j] +# T.copy(B[batch_idx, +# chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, +# bz // (nheads // ngroups), +# n_idx * block_N : (n_idx + 1) * block_N], +# B_shared) +# T.gemm(xt_local, B_shared, acc_o) +# T.copy(acc_o, acc_o_shared) +# 
T.copy(acc_o_shared, Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N]) +# return main + +# def chunk_state_ref(B, x, dt, dA_cumsum): +# from einops import rearrange, repeat +# """ +# Argument: +# B: (batch, seqlen, ngroups, headdim) +# x: (batch, seqlen, nheads, headdim) +# dt: (batch, nheads, nchunks, chunk_size) +# dA_cumsum: (batch, nheads, nchunks, chunk_size) +# Return: +# states: (batch, nchunks, nheads, headdim, dstate) +# """ +# # Check constraints. +# batch, seqlen, nheads, headdim = x.shape +# dstate = B.shape[-1] +# _, _, nchunks, chunk_size = dt.shape +# assert seqlen <= nchunks * chunk_size +# assert x.shape == (batch, seqlen, nheads, headdim) +# assert dt.shape == (batch, nheads, nchunks, chunk_size) +# ngroups = B.shape[2] +# assert nheads % ngroups == 0 +# assert B.shape == (batch, seqlen, ngroups, dstate) +# B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) +# assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) +# if seqlen < nchunks * chunk_size: +# x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) +# B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) +# x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) +# B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) +# decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) +# return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) + +def chunk_state(batch, seqlen, ngroups, nheads, headdim, dstate): + + def chunk_state_triton(B, x, dt, dA_cumsum): + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd + return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False) + + def chunk_state_ref(B, x, dt, dA_cumsum): + """ + Argument: + B: (batch, seqlen, ngroups, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + ngroups = B.shape[2] + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seqlen < nchunks * chunk_size: + x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) + B = rearrange(B, "b (c l) ... 
-> b c l ...", l=chunk_size) + decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) + + def get_configs(): + # block_M = [64, 128] + # block_N = [32, 64, 128] + # block_K = [32, 64] + # num_stages = [2,3,4,5] + block_M = [64] + block_N = [128] + block_K = [64] + num_stages = [4] + _configs = list(itertools.product(block_M, block_N, block_K, num_stages)) + + configs = [ + {'block_M': c[0], 'block_N': c[1], 'block_K': c[2], 'num_stages': c[3], 'thread_num': c[0] * 2} + for c in _configs + ] + return configs + + @autotune(configs=get_configs(), keys=['block_M', 'block_N', 'block_K', 'num_stages', 'thread_num'], warmup=10, rep=5) + @jit(out_idx=[4], supply_type=tl.TensorSupplyType.Normal, ref_prog=chunk_state_triton, check_close=False, rtol=0.01, atol=0.01, profiler="tvm") + def kernel(block_M = None, block_N = None, block_K = None, num_stages = None, thread_num = None): + dtype = "float16" + accum_dtype = "float" + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + @T.prim_func + def main( + B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), + x: T.Buffer((batch, seqlen, nheads, headdim), dtype), + dt: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), + dA_cumsum: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), + Output: T.Buffer((batch, nchunks, nheads, headdim, dstate), dtype) + ): + with T.Kernel(T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, nheads, threads=thread_num) as (bx, by, bz): + x_shared = T.alloc_shared((block_K, block_M), dtype) + x_local = T.alloc_fragment((block_K, block_M), dtype) + xt_local = T.alloc_fragment((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + dt_shared = T.alloc_shared((block_K), dtype) + dA_cumsum_shared = T.alloc_shared((block_K), dtype) + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + scale = T.alloc_fragment((block_K), accum_dtype) + dA_cs_last = T.alloc_fragment((1), accum_dtype) + dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype) + dt_local = T.alloc_fragment((block_K), accum_dtype) + + loop_range = T.ceildiv(chunk_size, block_K) + + batch_idx = by % batch + chunk_idx = by // batch + m_idx = bx // T.ceildiv(dstate, block_N) + n_idx = bx % T.ceildiv(dstate, block_N) + dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] + + T.annotate_layout({ + x_shared: tl.layout.make_swizzled_layout(x_shared), + acc_o_shared: tl.layout.make_swizzled_layout(acc_o_shared) + }) + + T.clear(acc_o) + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1,1,-1,2,-1,3,-1,0], + stage=[-1,0,-1,0,-1,0,-1,1], + group=[[0],[1],[2],[3],[4],[5,6,7],[8],[9]], + # order=[-1,-1,-1,1,-1,0], + # stage=[-1,-1,-1,0,-1,1], + # group=[[0],[1],[2],[3,4,5,6,7],[8],[9]], + ): + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared) + T.copy(dA_cumsum_shared, dA_cumsum_local) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) + T.copy(dt_shared, dt_local) + T.copy(x[batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + m_idx * block_M : (m_idx + 1) * block_M], + x_shared) + T.copy(x_shared, x_local) + for i in T.Parallel(block_K): + scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i] + for i, j in T.Parallel(block_M, block_K): + 
xt_local[i, j] = x_local[j, i] * scale[j] + T.copy(B[batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz // (nheads // ngroups), + n_idx * block_N : (n_idx + 1) * block_N], + B_shared) + T.gemm(xt_local, B_shared, acc_o) + T.copy(acc_o, acc_o_shared) + T.copy(acc_o_shared, Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N]) + return main + return kernel() + +#################################################################################################### +# chunk_scan +#################################################################################################### + +def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states): + from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd + out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states) + return out + +def chunk_scan_ref(cb, x, dt, dA_cumsum, C, prev_states): + from einops import rearrange, repeat + """ + Argument: + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + C: (batch, seqlen, ngroups, dstate) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + _, _, ngroups, _, _ = cb.shape + batch, seqlen, nheads, headdim = x.shape + # _, _, ngroups, dstate = B.shape + # assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + # assert C.shape == B.shape + # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups) + # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + # rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), + rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + prev_states.to(C.dtype)) * state_decay_out + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + + return out + +def chunk_scan_fwd(batch, seqlen, ngroups, nheads, headdim, dstate, block_M, block_N, block_K, block_Dstate): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 @T.prim_func def main( - A: T.Buffer((batch, seqlen, ngroups, k), dtype), - B: T.Buffer((batch, seqlen, ngroups, k), dtype), - Output: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype) + cb: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), + x: T.Buffer((batch, seqlen, nheads, headdim), dtype), + dt: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), + dA_cumsum: T.Buffer((batch, nheads, nchunks, 
chunk_size), dtype), + C: T.Buffer((batch, seqlen, ngroups, dstate), dtype), + prev_states: T.Buffer((batch, nchunks, nheads, headdim, dstate), dtype), + Output: T.Buffer((batch, seqlen, nheads, headdim), dtype) ): - with T.Kernel(T.ceildiv(chunk_size, block_M) * T.ceildiv(chunk_size, block_N), batch, nchunks * ngroups, threads=128) as (bx, by, bz): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_N, block_K), dtype) + with T.Kernel(T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, nheads, threads=128) as (bx, by, bz): acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) - chunk_idx = bz // ngroups - group_idx = bz % ngroups - m_idx = bx // T.ceildiv(chunk_size, block_N) - n_idx = bx % T.ceildiv(chunk_size, block_N) + # acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + cb_shared = T.alloc_shared((block_M, block_K), dtype) + cb_local = T.alloc_fragment((block_M, block_K), dtype) + dA_cs_k_shared = T.alloc_shared((block_M), dtype) + dA_cs_k_local = T.alloc_fragment((block_M), dtype) + dA_cs_m_shared = T.alloc_shared((block_M), dtype) + dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) + dt_shared = T.alloc_shared((block_K), dtype) + dt_local = T.alloc_fragment((block_K), accum_dtype) + x_shared = T.alloc_shared((block_K, block_N), dtype) + scale_m_local = T.alloc_fragment((block_M), accum_dtype) + C_shared = T.alloc_shared((block_M, block_Dstate), dtype) + prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) - loop_range = T.ceildiv(chunk_size, block_K) + + batch_idx = by % batch + chunk_idx = by // batch + # m: chunk_size + # n : headdim + m_idx = bx // T.ceildiv(headdim, block_N) + n_idx = bx % T.ceildiv(headdim, block_N) + + # T.annotate_layout({ + # acc_o_shared: tl.layout.make_swizzled_layout(acc_o_shared) + # }) + + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) + T.copy(dA_cs_m_shared, dA_cs_m_local) T.clear(acc_o) + + for i in T.Parallel(block_M): + scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) + T.copy( + C[batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0 : block_Dstate + ], + C_shared + ) + T.copy( + prev_states[batch_idx, + chunk_idx, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + 0 : block_Dstate + ], + prev_state_shared + ) + T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] *= scale_m_local[i] + + loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) + for k in T.Pipelined(loop_range, num_stages=1): - T.copy(A[by, - chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, - group_idx, - k * block_K : (k + 1) * block_K], - A_shared) - T.copy(B[by, - chunk_idx * chunk_size + n_idx * block_N : chunk_idx * chunk_size + (n_idx + 1) * block_N, - group_idx, - k * block_K : (k + 1) * block_K], - B_shared) - T.gemm(A_shared, B_shared, acc_o, transpose_B=True) - T.copy(acc_o, Output[by, chunk_idx, group_idx, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N]) + T.copy( + cb[batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K], + cb_shared + ) + T.copy(cb_shared, cb_local) + T.copy( + dA_cumsum[batch_idx, + bz, + chunk_idx, + k * block_K : (k + 1) * block_K], + dA_cs_k_shared + ) + T.copy(dA_cs_k_shared, dA_cs_k_local) + for i, j 
in T.Parallel(block_M, block_K): + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) + T.copy(dt_shared, dt_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] *= dt_local[j] + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = T.if_then_else( + m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0 + ) + T.copy(x[batch_idx, chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, bz, n_idx * block_N : (n_idx + 1) * block_N], x_shared) + T.gemm(cb_local, x_shared, acc_o) + # T.copy(acc_o, acc_o_shared) + T.copy(acc_o, Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, bz, n_idx * block_N : (n_idx + 1) * block_N]) return main -def ref_program(A, B): - from einops import rearrange, repeat - seqlen = A.shape[1] - nchunks = (seqlen + chunk_size - 1) // chunk_size +def bmm_chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + assert C.shape == B.shape + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), + rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + prev_states.to(C.dtype)) * state_decay_out + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return out if z is None else out * F.silu(z) + +def bmm_chunk_scan_fwd(batch, seqlen, ngroups, nheads, headdim, dstate, block_M, block_N, block_K, block_Dstate): + dtype = "float16" + accum_dtype = "float" + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + @T.prim_func + def main( + B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), + x: T.Buffer((batch, seqlen, nheads, headdim), dtype), + dt: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), + dA_cumsum: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), + C: T.Buffer((batch, seqlen, ngroups, dstate), dtype), + prev_states: T.Buffer((batch, 
nchunks, nheads, headdim, dstate), dtype), + Output: T.Buffer((batch, seqlen, nheads, headdim), dtype) + ): + with T.Kernel(T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, nheads, threads=128) as (bx, by, bz): + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + cb_shared = T.alloc_shared((block_M, block_K), dtype) + cb_local = T.alloc_fragment((block_M, block_K), dtype) + dA_cs_k_shared = T.alloc_shared((block_M), dtype) + dA_cs_k_local = T.alloc_fragment((block_M), dtype) + dA_cs_m_shared = T.alloc_shared((block_M), dtype) + dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) + dt_shared = T.alloc_shared((block_K), dtype) + dt_local = T.alloc_fragment((block_K), accum_dtype) + x_shared = T.alloc_shared((block_K, block_N), dtype) + scale_m_local = T.alloc_fragment((block_M), accum_dtype) + C_shared = T.alloc_shared((block_M, block_Dstate), dtype) + prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) + + + batch_idx = by % batch + chunk_idx = by // batch + # m: chunk_size + # n : headdim + m_idx = bx // T.ceildiv(headdim, block_N) + n_idx = bx % T.ceildiv(headdim, block_N) + + # T.annotate_layout({ + # acc_o_shared: tl.layout.make_swizzled_layout(acc_o_shared) + # }) + + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) + T.copy(dA_cs_m_shared, dA_cs_m_local) + T.clear(acc_o) + + for i in T.Parallel(block_M): + scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) + T.copy( + C[batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0 : block_Dstate + ], + C_shared + ) + T.copy( + prev_states[batch_idx, + chunk_idx, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + 0 : block_Dstate + ], + prev_state_shared + ) + T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] *= scale_m_local[i] + + loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) + + for k in T.Pipelined(loop_range, num_stages=4): + T.copy( + cb[batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K], + cb_shared + ) + T.copy(cb_shared, cb_local) + T.copy( + dA_cumsum[batch_idx, + bz, + chunk_idx, + k * block_K : (k + 1) * block_K], + dA_cs_k_shared + ) + T.copy(dA_cs_k_shared, dA_cs_k_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) + T.copy(dt_shared, dt_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] *= dt_local[j] + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = T.if_then_else( + m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0 + ) + T.copy(x[batch_idx, chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, bz, n_idx * block_N : (n_idx + 1) * block_N], x_shared) + T.gemm(cb_local, x_shared, acc_o) + T.copy(acc_o, acc_o_shared) + T.copy(acc_o_shared, Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, bz, n_idx * block_N : (n_idx + 1) * block_N]) + + return main + +# def chunk_scan_fwd(batch, seqlen, ngroups, nheads, headdim, dstate): + +# def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states): +# from mamba_ssm.ops.triton.ssd_chunk_scan import 
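# A worked form of what chunk_scan_fwd (and the bmm_chunk_scan_fwd variant) compute per chunk,
# sketched from chunk_scan_ref rather than transcribed from the kernels; cb is the per-chunk
# C @ B^T as in bmm_chunk_scan_ref above:
#   out[l] = sum_{s <= l} cb[l, s] * exp(dA_cumsum[l] - dA_cumsum[s]) * dt[s] * x[s]   # intra-chunk, causal
#          + exp(dA_cumsum[l]) * (prev_states @ C[l])                                  # inter-chunk carry-in
# The causal term is what loop_range = ceildiv((m_idx + 1) * block_M, block_K) together with the
# if_then_else mask on cb_local implement; the carry-in is the single C_shared x prev_state_shared
# GEMM scaled by scale_m_local.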
_chunk_scan_fwd +# out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states) +# return out + +# def get_configs(): +# # block_M = [64, 128] +# # block_N = [32, 64, 128] +# # block_K = [32, 64] +# # block_Dstate = [128] +# # num_stages = [2,3,4,5] +# block_M = [64] +# block_N = [64] +# block_K = [64] +# block_Dstate = [128] +# num_stages = [4] +# _configs = list(itertools.product(block_M, block_N, block_K, block_Dstate, num_stages)) + +# configs = [ +# {'block_M': c[0], 'block_N': c[1], 'block_K': c[2], 'block_Dstate': c[3], 'num_stages': c[4], 'thread_num': c[0] * 2} +# for c in _configs +# ] +# return configs + +# @autotune(configs=get_configs(), keys=['block_M', 'block_N', 'block_K', 'block_Dstate', 'num_stages', 'thread_num'], warmup=10, rep=5) +# @jit(out_idx=[6], supply_type=tl.TensorSupplyType.Normal, ref_prog=chunk_scan_triton, check_close=False, rtol=0.01, atol=0.01, profiler="tvm") +# def kernel(block_M = None, block_N = None, block_K = None, block_Dstate=None, num_stages = None, thread_num = None): +# dtype = "float16" +# accum_dtype = "float" +# nchunks = T.ceildiv(seqlen, chunk_size) +# p = 1.44269504 +# @T.prim_func +# def main( +# cb: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), +# x: T.Buffer((batch, seqlen, nheads, headdim), dtype), +# dt: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), +# dA_cumsum: T.Buffer((batch, nheads, nchunks, chunk_size), dtype), +# C: T.Buffer((batch, seqlen, ngroups, dstate), dtype), +# prev_states: T.Buffer((batch, nchunks, nheads, headdim, dstate), dtype), +# Output: T.Buffer((batch, seqlen, nheads, headdim), dtype) +# ): +# with T.Kernel(T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, nheads, threads=thread_num) as (bx, by, bz): +# acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) +# acc_o_shared = T.alloc_shared((block_M, block_N), dtype) +# cb_shared = T.alloc_shared((block_M, block_K), dtype) +# cb_local = T.alloc_fragment((block_M, block_K), dtype) +# dA_cs_k_shared = T.alloc_shared((block_M), dtype) +# dA_cs_k_local = T.alloc_fragment((block_M), dtype) +# dA_cs_m_shared = T.alloc_shared((block_M), dtype) +# dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) +# dt_shared = T.alloc_shared((block_K), dtype) +# dt_local = T.alloc_fragment((block_K), accum_dtype) +# x_shared = T.alloc_shared((block_K, block_N), dtype) +# scale_m_local = T.alloc_fragment((block_M), accum_dtype) +# C_shared = T.alloc_shared((block_M, block_Dstate), dtype) +# prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) + + +# batch_idx = by % batch +# chunk_idx = by // batch +# # m: chunk_size +# # n : headdim +# m_idx = bx // T.ceildiv(headdim, block_N) +# n_idx = bx % T.ceildiv(headdim, block_N) + +# T.annotate_layout({ +# acc_o_shared: tl.layout.make_swizzled_layout(acc_o_shared) +# }) + +# T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) +# T.copy(dA_cs_m_shared, dA_cs_m_local) +# T.clear(acc_o) + +# for i in T.Parallel(block_M): +# scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) +# T.copy( +# C[batch_idx, +# chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, +# bz // (nheads // ngroups), +# 0 : block_Dstate +# ], +# C_shared +# ) +# T.copy( +# prev_states[batch_idx, +# chunk_idx, +# bz, +# n_idx * block_N : (n_idx + 1) * block_N, +# 0 : block_Dstate +# ], +# prev_state_shared +# ) +# T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) +# for i, j in T.Parallel(block_M, block_N): 
+# acc_o[i, j] *= scale_m_local[i] + +# loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) + +# for k in T.Pipelined(loop_range, num_stages=num_stages): +# T.copy( +# cb[batch_idx, +# chunk_idx, +# bz // (nheads // ngroups), +# m_idx * block_M : (m_idx + 1) * block_M, +# k * block_K : (k + 1) * block_K], +# cb_shared +# ) +# T.copy(cb_shared, cb_local) +# T.copy( +# dA_cumsum[batch_idx, +# bz, +# chunk_idx, +# k * block_K : (k + 1) * block_K], +# dA_cs_k_shared +# ) +# T.copy(dA_cs_k_shared, dA_cs_k_local) +# for i, j in T.Parallel(block_M, block_K): +# cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) +# T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) +# T.copy(dt_shared, dt_local) +# for i, j in T.Parallel(block_M, block_K): +# cb_local[i, j] *= dt_local[j] +# for i, j in T.Parallel(block_M, block_K): +# cb_local[i, j] = T.if_then_else( +# m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0 +# ) +# T.copy(x[batch_idx, chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, bz, n_idx * block_N : (n_idx + 1) * block_N], x_shared) +# T.gemm(cb_local, x_shared, acc_o) +# T.copy(acc_o, acc_o_shared) +# T.copy(acc_o_shared, Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, bz, n_idx * block_N : (n_idx + 1) * block_N]) + +# return main +# return kernel() + +def state_passing_fwd(batch, seqlen, nheads, headdim, block_M): + dtype = "float16" + accum_dtype = "float" + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + @T.prim_func + def main( + states: T.Buffer((batch, nchunks, nheads, headdim), dtype), + dA_chunk_cumsum: T.Buffer((batch, nheads, nchunks), dtype), + initial_states: T.Buffer((batch, nheads, headdim), dtype), + Output: T.Buffer((batch, nchunks + 1, nheads, headdim), dtype), + ): + with T.Kernel(T.ceildiv(headdim, block_M), batch, nheads, threads=128) as (bx, by, bz): + # state_shared = T.alloc_shared((block_M), dtype) + dA_cs_local = T.alloc_fragment((1,1), accum_dtype) + scale = T.alloc_fragment((1,1), accum_dtype) + state_local = T.alloc_fragment((block_M), accum_dtype) + new_state_local = T.alloc_fragment((block_M), accum_dtype) + + T.annotate_layout({ + dA_cs_local: tl.layout.make_swizzled_layout(dA_cs_local), + }) + + batch_idx = by + head_idx = bz + m_idx = bx + + T.clear(state_local) + T.copy(initial_states[batch_idx, head_idx, m_idx * block_M : (m_idx + 1) * block_M], state_local) + T.copy(state_local, Output[batch_idx, 0, head_idx, m_idx * block_M : (m_idx + 1) * block_M]) + # T.copy(state_shared, state_local) + for k in T.Pipelined(nchunks, num_stages=1): + # T.copy(states[batch_idx, k, head_idx, m_idx * block_M : (m_idx + 1) * block_M], state_shared) + # T.copy(state_shared, new_state_local) + for i in T.Parallel(block_M): + new_state_local[i] = states[batch_idx, k, head_idx, m_idx * block_M + i] + dA_cs_local[0,0] = dA_chunk_cumsum[batch_idx, head_idx, k] + scale[0,0] = T.exp2(dA_cs_local[0,0] * p) + for i in T.Parallel(block_M): + state_local[i] = state_local[i] * scale[0,0] + new_state_local[i] + T.copy(state_local, Output[batch_idx, k + 1, head_idx, m_idx * block_M : (m_idx + 1) * block_M]) + + return main + +def state_passing_ref(states, dA_chunk_cumsum, initial_states): + """ + Argument: + states: (batch, nchunks, nheads, dim) + dA_chunk_cumsum: (batch, nheads, nchunks) + initial_states: (batch, nheads, dim) + Return: + out: (batch, nchunks, nheads, dim) + final_states: (batch, nheads, dim) + """ 
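# state_passing_fwd above walks this recurrence sequentially per (batch, head, dim-block):
#   out[:, 0]     = initial_states
#   out[:, k + 1] = out[:, k] * exp(dA_chunk_cumsum[:, :, k, None]) + states[:, k]
# while the reference below expresses the same scan in closed form through a lower-triangular
# decay matrix. A loop sketch, assuming plain torch tensors shaped as in the docstring:
#   out = [initial_states]
#   for k in range(states.shape[1]):
#       out.append(out[-1] * dA_chunk_cumsum[:, :, k].exp()[..., None] + states[:, k])
#   out = torch.stack(out, dim=1)   # (batch, nchunks + 1, nheads, dim); extra leading slot holds the initial state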
+ if initial_states is None: + initial_states = torch.zeros_like(states[:, 0]) + states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1) + dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0)) + dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1) + nchunks = dA_chunk_cumsum.shape[-1] + # (batch, nheads, nchunks, nchunks) + dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :] + # (batch, nheads, nchunks, nchunks) + decay_chunk = torch.exp(dt_chunk_segment_sum) + causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0) + decay_chunk = decay_chunk.masked_fill(~causal_mask, 0) + out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states) + return out + +def selective_scan_update_fwd(batch, seqlen, nheads, ngroups, headdim, dstate, block_M, block_Dstate): + dtype = "float16" + accum_dtype = "float" + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + assert dstate == block_Dstate + @T.prim_func + def main( + state: T.Buffer((batch, nheads, headdim, dstate), dtype), + x: T.Buffer((batch, nheads, headdim), dtype), + dt: T.Buffer((batch, nheads, headdim), dtype), + A: T.Buffer((nheads, headdim, dstate), dtype), + B: T.Buffer((batch, ngroups, dstate), dtype), + C: T.Buffer((batch, ngroups, dstate), dtype), + Output: T.Buffer((batch, nheads, headdim), dtype) + ): + with T.Kernel(T.ceildiv(headdim, block_M), batch, nheads, threads=128) as (bx, by, bz): + state_shared = T.alloc_shared((block_M, block_Dstate), dtype) + state_local = T.alloc_fragment((block_M, block_Dstate), accum_dtype) + # new_state_local = T.alloc_fragment((block_M, block_Dstate), accum_dtype) + x_shared = T.alloc_shared((block_M), dtype) + x_local = T.alloc_fragment((block_M), accum_dtype) + dt_shared = T.alloc_shared((block_M), dtype) + dt_local = T.alloc_fragment((block_M), accum_dtype) + A_shared = T.alloc_shared((block_M, block_Dstate), dtype) + A_local = T.alloc_fragment((block_M, block_Dstate), accum_dtype) + dA_local = T.alloc_fragment((block_M, block_Dstate), accum_dtype) + B_shared = T.alloc_shared((block_Dstate), dtype) + C_shared = T.alloc_shared((block_Dstate), dtype) + C_local = T.alloc_fragment((block_Dstate), accum_dtype) + B_local = T.alloc_fragment((block_Dstate), accum_dtype) + dB_local = T.alloc_fragment((block_M, block_Dstate), accum_dtype) + state_sum_local = T.alloc_fragment((block_M), accum_dtype) + + batch_idx = by + head_idx = bz + m_idx = bx - A = rearrange(A, "b (c l) g d -> b c l g d", c=nchunks) - B = rearrange(B, "b (c l) g d -> b c l g d", c=nchunks) + # T.annotate_layout({ + # new_state_local: tl.layout.make_swizzled_layout(state_shared), + # }) + + T.copy(state[batch_idx, head_idx, m_idx * block_M : (m_idx + 1) * block_M, :], state_shared) + T.copy(state_shared, state_local) + T.copy(x[batch_idx, head_idx, m_idx * block_M : (m_idx + 1) * block_M], x_shared) + T.copy(x_shared, x_local) + # Not TIE_HDIM + T.copy(dt[batch_idx, head_idx, m_idx * block_M : (m_idx + 1) * block_M], dt_shared) + T.copy(dt_shared, dt_local) + T.copy(A[head_idx, m_idx * block_M : (m_idx + 1) * block_M, :], A_shared) + T.copy(A_shared, A_local) + for i, j in T.Parallel(block_M, block_Dstate): + dA_local[i, j] = T.exp2(A_local[i, j] * dt_local[i] * p) + T.copy(B[batch_idx, bz // (nheads // ngroups), :], B_shared) + T.copy(B_shared, B_local) + T.copy(C[batch_idx, bz // (nheads // ngroups), :], C_shared) + T.copy(C_shared, C_local) + for i, j in T.Parallel(block_M, block_Dstate): + dB_local[i, j] = 
B_local[j] * dt_local[i] + for i, j in T.Parallel(block_M, block_Dstate): + state_local[i, j] *= dA_local[i, j] + for i, j in T.Parallel(block_M, block_Dstate): + state_local[i, j] += dB_local[i, j] * x_local[i] + for i, j in T.Parallel(block_M, block_Dstate): + state_local[i, j] *= C_local[j] + T.reduce_sum(state_local, state_sum_local, dim=1) + T.copy(state_sum_local, Output[batch_idx, head_idx, m_idx * block_M : (m_idx + 1) * block_M]) + + return main - return torch.einsum("bclgd,bcsgd->bcgls", A, B) +def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) + state_ = state * dA + dB * rearrange(x, "b h d -> b h d 1") # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state_.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out if __name__ == "__main__": - BATCH, SEQLEN, NGROUPS, DSTATE = 8, 4096, 16, 64 - block_M, block_N, block_K = 64, 64, 64 - program = bmm_chunk(BATCH, SEQLEN, NGROUPS, DSTATE, block_M, block_N, block_K) + BATCH, NHEADS, NGROUPS, SEQLEN, HEADDIM, DSTATE = 8, 80, 1, 8192, 64, 128 + # BATCH, NHEADS, NGROUPS, SEQLEN, HEADDIM, DSTATE = 1, 1, 1, 256, 64, 128 + block_M, block_N, block_K, block_Dstate = 64, 64, 64, 128 + # chunk_cumsum_fwd + + # state_passing_fwd + # BATCH, SEQLEN, NHEADS, HEADDIM = 4, 2048, 8, 64 + # block_M = 64 + # program = state_passing_fwd(BATCH, SEQLEN, NHEADS, HEADDIM, block_M) + # mod, params = tl.lower(program) + # mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + # mod.assert_allclose(state_passing_ref, 
rtol=0.01, atol=0.01) + + + # chunk_state_fwd + # total_flops = 2 * BATCH * SEQLEN * NHEADS * HEADDIM * DSTATE + # best_latency, best_config, ref_latency = chunk_state(BATCH, SEQLEN, NGROUPS, NHEADS, HEADDIM, DSTATE) + # print(f"Best latency: {best_latency}") + # print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + # print(f"Best config: {best_config}") + # print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") + # program = chunk_state_fwd(BATCH, SEQLEN, NGROUPS, NHEADS, HEADDIM, DSTATE, block_M, block_N, block_K) + # mod, params = tl.lower(program) + # mod = tl.Profiler(mod, params, [4], tl.TensorSupplyType.Normal) + # # mod.assert_allclose(chunk_state_triton, rtol=0.01, atol=0.01) + # latency = mod.do_bench(chunk_state_triton, n_warmup=10, n_repeat=10, profiler="torch") + # print("{:.2f} ms".format(latency)) + # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + # latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="tvm") + # print("{:.2f} ms".format(latency)) + # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + + + # bmm_chunk + # total_flops = 2 * BATCH * SEQLEN * NGROUPS * DSTATE * chunk_size + # best_latency, best_config, ref_latency = bmm_chunk(BATCH, SEQLEN, NGROUPS, DSTATE) + # print(f"Best latency: {best_latency}") + # print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + # print(f"Best config: {best_config}") + # print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") + # program = bmm_chunk(BATCH, SEQLEN, NGROUPS, DSTATE, block_M, block_N, block_K, 2, 128) + # mod, params = tl.lower(program) + # mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Normal) + # mod.assert_allclose(bmm_triton, rtol=0.1, atol=0.1) + # total_flops = 2 * BATCH * SEQLEN * NGROUPS * DSTATE * chunk_size + # latency = mod.do_bench(bmm_triton, n_warmup=10, n_repeat=10, profiler="tvm") + # print("{:.2f} ms".format(latency)) + # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + # latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="tvm") + # print("{:.2f} ms".format(latency)) + # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + + # chunk_scan_fwd + total_flops = 2.0 * BATCH * SEQLEN * chunk_size * NHEADS * HEADDIM * 0.5 + 2.0 * BATCH * SEQLEN * NHEADS * HEADDIM * DSTATE + # best_latency, best_config, ref_latency = chunk_scan_fwd(BATCH, SEQLEN, NGROUPS, NHEADS, HEADDIM, DSTATE) + # print(f"Best latency: {best_latency}") + # print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + # print(f"Best config: {best_config}") + # print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") + program = chunk_scan_fwd(BATCH, SEQLEN, NGROUPS, NHEADS, HEADDIM, DSTATE, block_M, block_N, block_K, block_Dstate) mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Normal) - mod.assert_allclose(ref_program, rtol=0.1, atol=0.1) \ No newline at end of file + mod = tl.Profiler(mod, params, [6], tl.TensorSupplyType.Normal) + mod.assert_allclose(chunk_scan_ref, rtol=0.01, atol=0.01) + latency = mod.do_bench(chunk_scan_ref, n_warmup=10, n_repeat=10, profiler="torch") + print("{:.2f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="tvm") + print("{:.2f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + + # selective_state_update_fwd + # BATCH, SEQLEN, NHEADS, NGROUPS, HEADDIM, DSTATE = 1, 4096, 1, 1, 64, 64 + # block_M, block_Dstate = 64, 64 + # program = 
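# How the chunk_scan numbers above are obtained (a sketch; chunk_size is defined elsewhere in this file):
#   total_flops = 2 * BATCH * SEQLEN * chunk_size * NHEADS * HEADDIM * 0.5   # causal intra-chunk GEMM, ~half the tile is masked
#               + 2 * BATCH * SEQLEN * NHEADS * HEADDIM * DSTATE             # C @ prev_states carry-in GEMM
# do_bench reports latency in milliseconds (it is printed as "ms"), so
#   TFLOPS = total_flops / (latency * 1e-3) / 1e12 = total_flops / latency * 1e-9
# which is the expression printed after each benchmark.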
selective_scan_update_fwd(BATCH, SEQLEN, NHEADS, NGROUPS, HEADDIM, DSTATE, block_M, block_Dstate) + # mod, params = tl.lower(program) + # mod = tl.Profiler(mod, params, [6], tl.TensorSupplyType.Normal) + # mod.assert_allclose(selective_state_update_ref, rtol=0.1, atol=0.1) \ No newline at end of file diff --git a/tl_scripts/mha_test.py b/tl_scripts/mha_test.py deleted file mode 100644 index e52733ad3be7..000000000000 --- a/tl_scripts/mha_test.py +++ /dev/null @@ -1,191 +0,0 @@ -import torch -from tvm import tl -import tvm.tl.language as T -from functools import partial - -# This script gives a wrong result when dim=64. -# The error is due to the acc_s_cast tensor reuse the register of Q_local tensor (don't know why). -# It is a strange error because in PTX file, the register of Q_local and acc_s_cast are different. -# To reproduce the error, you can try the following script: -# with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): -# Q_shared = T.alloc_shared([block_M, dim], dtype) -# Q_local = T.alloc_fragment([block_M, dim], dtype) -# K_shared = T.alloc_shared([block_N, dim], dtype) -# V_shared = T.alloc_shared([block_N, dim], dtype) -# acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) -# acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) -# acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - -# T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) -# T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) -# T.fill(acc_o, 0) -# T.copy(Q_shared, Q_local) -# for i, j in T.Parallel(block_M, dim): -# Q_local[i, j] *= scale -# loop_range = ( -# T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) -# ) -# for k in T.Pipelined(loop_range, num_stages=1): -# T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) -# if is_casual: -# for i, j in T.Parallel(block_M, block_N): -# acc_s[i, j] = T.if_then_else( -# bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) -# ) -# else: -# T.clear(acc_s) -# T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) -# T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) -# for i, j in T.Parallel(block_M, block_N): -# acc_s[i, j] = T.exp2(acc_s[i, j] - 32) -# T.copy(acc_s, acc_s_cast) -# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) -# T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) - -# To fix this, we can either use T.gemm(Q_shared, K_shared, acc_s), like in FlashAttention implementation, -# or use different wgmma instrutcion (like M64N32K16) - -def flashattn(batch, heads, seq_len, dim, is_casual, block_M, block_N): - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" - - @T.prim_func - def main( - Q: T.Buffer(shape, dtype), - K: T.Buffer(shape, dtype), - V: T.Buffer(shape, dtype), - Output: T.Buffer(shape, dtype), - ): - with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - Q_local = T.alloc_fragment([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = 
T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) - T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.copy(Q_shared, Q_local) - for i, j in T.Parallel(block_M, dim): - Q_local[i, j] *= scale - loop_range = ( - T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N) - ) - for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) - if is_casual: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else( - bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) - ) - else: - T.clear(acc_s) - T.gemm(Q_local, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) - T.copy(acc_s, acc_s_cast) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) - - return main - - -def ref_program(Q, K, V, casual): - # from flash_attn.flash_attn_interface import flash_attn_func - - # return flash_attn_func(Q, K, V, causal=casual) - assert casual == False, "casual is not supported" - batch, seq_len, heads, dim = Q.size() - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - block_M = seq_len - block_N = 64 if dim <= 128 else 32 - acc_s = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float) - acc_s_cast = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float16) - acc_o = torch.empty((batch, block_M, heads, dim), device="cuda", dtype=torch.float) - scores_max = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_max_prev = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_scale = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_sum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - logsum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - acc_o.fill_(0) - logsum.fill_(0) - scores_max.fill_(float('-inf')) - Q_scaled = Q * scale - - for i in range(int(seq_len / block_N)): - acc_s.fill_(0) - acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_scaled, K[:, i * block_N : (i + 1) * block_N, :, :]) # [batch, seqlen, heads, block_N] - scores_max_prev = scores_max - scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] - scores_scale = torch.exp2(scores_max_prev - scores_max) - acc_o *= scores_scale[:, :, :, None].transpose(1, 2) - acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) - acc_s_cast = acc_s.to(torch.float16) - acc_o += 
torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, i * block_N : (i + 1) * block_N, :, :]) - scores_sum = acc_s.sum(dim=-1, keepdim=False) - logsum = logsum * scores_scale + scores_sum - acc_o /= logsum[:, :, :, None].transpose(1, 2) - return acc_o.to(torch.float16) - -# def ref_program(Q, K, V, casual): -# dim = Q.size(-1) -# scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) - -# # Step 2: Scale the scores by the square root of dim -# scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) - -# # Step 3: Apply softmax to get the attention weights -# attention_weights = F.softmax(scores, dim=-1) - -# # Step 4: Multiply the attention weights by the values (V) -# # This gives us the final output of shape [batch, seq_len, heads, dim] -# output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) - -# return output - -if __name__ == "__main__": - BATCH, H, N_CTX, D_HEAD = 1, 1, 64, 64 - casual = False - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD - total_flops = 2 * flops_per_matmul - if casual: - total_flops *= 0.5 - BLOCK_M = 64 - BLOCK_N = 64 if D_HEAD <= 128 else 32 - program = flashattn(BATCH, H, N_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) - ref_program = partial(ref_program, casual=casual) - mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) - mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) - - # latency = mod.do_bench(ref_program, warmup=500) - # print("{:.2f} ms".format(latency)) - # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) - # latency = mod.do_bench(mod) - # print("{:.2f} ms".format(latency)) - # print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file diff --git a/tl_scripts/profile.py b/tl_scripts/profile_workloads.py similarity index 100% rename from tl_scripts/profile.py rename to tl_scripts/profile_workloads.py diff --git a/tl_scripts/retnet_example.py b/tl_scripts/retnet_example.py index 2782078069db..567752e36b9f 100644 --- a/tl_scripts/retnet_example.py +++ b/tl_scripts/retnet_example.py @@ -91,6 +91,14 @@ def ref_program(Q, K, V, mask): o = torch.einsum('bhqk,bkhd->bqhd', qkm/r, V) return o.to(dtype=torch.float16) +def retnet_triton(Q, K, V, mask): + import sys + sys.path.append("/home/msra/cy/tvm.tl/3rdparty/flash-linear-attention") + from fla.ops.retention.parallel import parallel_retention + # Todo: mask + out = parallel_retention(Q, K, V) + return out + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/tl_verify/compile.py b/tl_verify/compile.py deleted file mode 100644 index 6504e61656bc..000000000000 --- a/tl_verify/compile.py +++ /dev/null @@ -1,46 +0,0 @@ -import os -import os.path as osp -# from tvm.contrib import nvcc -import subprocess - -with open("gemmx1.cu", "r") as f: - code = f.read() - -tvm_root = osp.join(osp.dirname(__file__), "../..") -tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) -if "TL_CUTLASS_PATH" in os.environ: - cutlass_path = os.environ["TL_CUTLASS_PATH"] -else: - cutlass_path = osp.abspath(osp.join(tvm_root, "3rdparty/cutlass/include")) - - -format = "ptx" -arch = f"sm_90a" - -# print(tl_template_path) -# print(cutlass_path) - -nvcc_command = [ - "nvcc", - "-o", "gemmx1", - "-arch=" + arch, - "--use_fast_math", - "-std=c++17", - "-I" + tl_template_path, - "-I" + cutlass_path, - "-lcuda", - "gemmx1.cu" -] - -subprocess.run(nvcc_command, check=True) - -# nvcc -ptx fa_kernel.cu -o fa_kernel.ptx -O3 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ 
-U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -arch=sm_90a --use_fast_math -std=c++17 -I/home/msra/cy/tvm.tl/src/tl -I/home/msra/cy/tvm.tl/cutlass/include -lcuda -"-O3", -"-U__CUDA_NO_HALF_OPERATORS__", -"-U__CUDA_NO_HALF_CONVERSIONS__", -"-U__CUDA_NO_BFLOAT16_OPERATORS__", -"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", -"-U__CUDA_NO_BFLOAT162_OPERATORS__", -"-U__CUDA_NO_BFLOAT162_CONVERSIONS__", -"--expt-relaxed-constexpr", -"--expt-extended-lambda", \ No newline at end of file diff --git a/tl_verify/cuda_interface.cpp b/tl_verify/cuda_interface.cpp deleted file mode 100644 index 290187a527df..000000000000 --- a/tl_verify/cuda_interface.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include -#include -#include -#include -#include -#include "fa_kernel.hpp" - -void main_kernel_launcher(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output, bool causal); -void main_kernel_launcher_no_tma(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output); - -at::Tensor kernel_function(at::Tensor Q, at::Tensor K, at::Tensor V, bool causal) { - at::Tensor output = torch::empty_like(Q); - main_kernel_launcher(Q, K, V, output, causal); - return output; -} - -at::Tensor kernel_function_no_tma(at::Tensor Q, at::Tensor K, at::Tensor V) { - at::Tensor output = torch::empty_like(Q); - main_kernel_launcher_no_tma(Q, K, V, output); - return output; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("kernel_function", &kernel_function, "FA Kernel Function"); - m.def("kernel_function_no_tma", &kernel_function_no_tma, "FA Kernel Launcher"); -} - -void main_kernel_launcher(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output, bool causal) { - int batch = Q.size(0); - int seq_len = Q.size(1); - int heads = Q.size(2); - int dim = Q.size(3); - int block_M = 0; - int block_N = 0; - int threads = 0; - - if (dim == 64) { - block_M = 192; - block_N = 128; - threads = 16 * 32; - } else if (dim == 128) { - block_M = 128; - block_N = causal ? 
128 : 176; - threads = 12 * 32; - } else if (dim == 256) { - block_M = 128; - block_N = 80; - threads = 12 * 32; - } else { - throw std::invalid_argument("Invalid dimension"); - } - host_function(Flash_fwd_params{Q.data_ptr(), K.data_ptr(), V.data_ptr(), output.data_ptr(), batch, seq_len, heads, dim, block_M, block_N, threads}); -} - -void main_kernel_launcher_no_tma(at::Tensor Q, at::Tensor K, at::Tensor V, at::Tensor output) { - host_function_no_tma(Flash_fwd_params{Q.data_ptr(), K.data_ptr(), V.data_ptr(), output.data_ptr(), Q.size(0), Q.size(1), Q.size(2), Q.size(3), 64, 64}); -} \ No newline at end of file diff --git a/tl_verify/fa_kernel.cu b/tl_verify/fa_kernel.cu deleted file mode 100644 index 746fbbf97094..000000000000 --- a/tl_verify/fa_kernel.cu +++ /dev/null @@ -1,429 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "fa_kernel.hpp" - -extern "C" __global__ void __launch_bounds__(512) main_kernel(__grid_constant__ const CUtensorMap K_desc, __grid_constant__ const CUtensorMap Output_desc, __grid_constant__ const CUtensorMap Q_desc, __grid_constant__ const CUtensorMap V_desc) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - float acc_o[32]; - float logsum[2]; - float scores_max[2]; - float acc_s[64]; - float scores_max_prev[2]; - float scores_scale[2]; - float scores_sum[2]; - half_t acc_s_cast[64]; - __shared__ uint64_t _mbarrier[11]; - if (((int)threadIdx.x) == 0) { - tl::prefetch_tma_descriptor(Q_desc); - tl::prefetch_tma_descriptor(K_desc); - tl::prefetch_tma_descriptor(V_desc); - tl::prefetch_tma_descriptor(Output_desc); - tl::mbarrier_init(_mbarrier[0], 128); - tl::mbarrier_init(_mbarrier[1], 128); - tl::mbarrier_init(_mbarrier[2], 128); - tl::mbarrier_init(_mbarrier[3], 128); - tl::mbarrier_init(_mbarrier[4], 384); - tl::mbarrier_init(_mbarrier[5], 384); - tl::mbarrier_init(_mbarrier[6], 384); - tl::mbarrier_init(_mbarrier[7], 384); - tl::mbarrier_init(_mbarrier[8], 128); - tl::mbarrier_init(_mbarrier[9], 384); - tl::mbarrier_init(_mbarrier[10], 384); - } - __syncthreads(); - if (384 <= ((int)threadIdx.x)) { - tl::warpgroup_reg_dealloc<32>(); - if (((int)threadIdx.x) == 384) { - tl::mbarrier_expect_tx(_mbarrier[8], 24576); - } - if (((int)threadIdx.x) == 384) { - tl::tma_load(Q_desc, _mbarrier[8], (&(((half_t*)buf_dyn_shmem)[0])), 0, ((int)blockIdx.y), (((int)blockIdx.x) * 192), ((int)blockIdx.z)); - } - tl::mbarrier_arrive(_mbarrier[8]); - for (int k = 0; k < 4; ++k) { - tl::mbarrier_wait(_mbarrier[((k & 1) + 4)], ((k >> 1) ^ 1)); - if (((int)threadIdx.x) == 384) { - tl::mbarrier_expect_tx(_mbarrier[(k & 1)], 16384); - } - if (((int)threadIdx.x) == 384) { - tl::tma_load(K_desc, _mbarrier[(k & 1)], (&(((half_t*)buf_dyn_shmem)[(((k & 1) * 8192) + 12288)])), 0, ((int)blockIdx.y), (k * 128), ((int)blockIdx.z)); - } - tl::mbarrier_arrive(_mbarrier[(k & 1)]); - tl::mbarrier_wait(_mbarrier[((k & 1) + 6)], ((k >> 1) ^ 1)); - if (((int)threadIdx.x) == 384) { - tl::mbarrier_expect_tx(_mbarrier[((k & 1) + 2)], 16384); - } - if (((int)threadIdx.x) == 384) { - tl::tma_load(V_desc, _mbarrier[((k & 1) + 2)], (&(((half_t*)buf_dyn_shmem)[(((k & 1) * 8192) + 28672)])), 0, ((int)blockIdx.y), (k * 128), ((int)blockIdx.z)); - } - tl::mbarrier_arrive(_mbarrier[((k & 1) + 2)]); - } - } else { - tl::warpgroup_reg_alloc<160>(); - #pragma unroll - for (int i = 0; i < 32; ++i) { - acc_o[i] = 0.000000e+00f; - } - #pragma unroll - for (int i_1 = 0; i_1 < 2; ++i_1) { - logsum[i_1] = 0.000000e+00f; - } - #pragma unroll - for (int i_2 = 0; i_2 < 2; ++i_2) { - 
scores_max[i_2] = -CUDART_INF_F; - } - tl::fence_proxy_async(); - tl::mbarrier_wait(_mbarrier[8], 0); - #pragma unroll - for (int i_3 = 0; i_3 < 64; ++i_3) { - acc_s[i_3] = 0.000000e+00f; - } - tl::fence_proxy_async(); - tl::mbarrier_wait(_mbarrier[0], 0); - tl::gemm_ss<192, 128, 64, 12, 1, 0, 1>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[12288])), (&(acc_s[0]))); - tl::mbarrier_arrive(_mbarrier[4]); - #pragma unroll - for (int i_4 = 0; i_4 < 2; ++i_4) { - scores_max_prev[i_4] = scores_max[i_4]; - } - #pragma unroll - for (int i_5 = 0; i_5 < 2; ++i_5) { - scores_max[i_5] = -CUDART_INF_F; - } - #pragma unroll - for (int i_6 = 0; i_6 < 2; ++i_6) { - #pragma unroll - for (int rv = 0; rv < 32; ++rv) { - scores_max[i_6] = max(scores_max[i_6], acc_s[((((rv & 15) * 4) + (i_6 * 2)) + (rv >> 4))]); - } - scores_max[i_6] = tl::AllReduce::run(scores_max[i_6]); - } - #pragma unroll - for (int i_7 = 0; i_7 < 2; ++i_7) { - scores_scale[i_7] = exp2f(((scores_max_prev[i_7] * 1.803369e-01f) - (scores_max[i_7] * 1.803369e-01f))); - } - #pragma unroll - for (int i_8 = 0; i_8 < 64; ++i_8) { - acc_s[i_8] = exp2f(((acc_s[i_8] * 1.803369e-01f) - (scores_max[((i_8 & 3) >> 1)] * 1.803369e-01f))); - } - #pragma unroll - for (int i_9 = 0; i_9 < 2; ++i_9) { - scores_sum[i_9] = 0.000000e+00f; - #pragma unroll - for (int rv_1 = 0; rv_1 < 32; ++rv_1) { - scores_sum[i_9] = (scores_sum[i_9] + acc_s[((((rv_1 & 15) * 4) + (i_9 * 2)) + (rv_1 >> 4))]); - } - scores_sum[i_9] = tl::AllReduce::run(scores_sum[i_9]); - } - #pragma unroll - for (int i_10 = 0; i_10 < 2; ++i_10) { - logsum[i_10] = ((logsum[i_10] * scores_scale[i_10]) + scores_sum[i_10]); - } - #pragma unroll - for (int i_11 = 0; i_11 < 32; ++i_11) { - acc_o[i_11] = (acc_o[i_11] * scores_scale[((i_11 & 3) >> 1)]); - } - #pragma unroll - for (int i_12 = 0; i_12 < 64; ++i_12) { - acc_s_cast[i_12] = ((half_t)acc_s[i_12]); - } - #pragma unroll 1 -for (int k_1 = 0; k_1 < 3; ++k_1) { - #pragma unroll - for (int i_13 = 0; i_13 < 64; ++i_13) { - acc_s[i_13] = 0.000000e+00f; - } - tl::fence_proxy_async(); - tl::mbarrier_wait(_mbarrier[((k_1 + 1) & 1)], ((k_1 + 1) >> 1)); - tl::gemm_ss<192, 128, 64, 12, 1, 0, 1,-1>((&(((half_t*)buf_dyn_shmem)[0])), (&(((half_t*)buf_dyn_shmem)[((((k_1 + 1) & 1) * 8192) + 12288)])), (&(acc_s[0]))); - - tl::mbarrier_wait(_mbarrier[((k_1 & 1) + 2)], (k_1 >> 1)); - tl::gemm_rs<192, 64, 128, 12, 1, 0, 0,-1>((&(acc_s_cast[0])), (&(((half_t*)buf_dyn_shmem)[(((k_1 & 1) * 8192) + 28672)])), (&(acc_o[0]))); - - cute::warpgroup_wait<1>(); - tl::mbarrier_arrive(_mbarrier[(((k_1 + 1) & 1) + 4)]); - #pragma unroll - for (int i_14 = 0; i_14 < 2; ++i_14) { - scores_max_prev[i_14] = scores_max[i_14]; - } - #pragma unroll - for (int i_15 = 0; i_15 < 2; ++i_15) { - scores_max[i_15] = -CUDART_INF_F; - } - #pragma unroll - for (int i_16 = 0; i_16 < 2; ++i_16) { - #pragma unroll - for (int rv_2 = 0; rv_2 < 32; ++rv_2) { - scores_max[i_16] = max(scores_max[i_16], acc_s[((((rv_2 & 15) * 4) + (i_16 * 2)) + (rv_2 >> 4))]); - } - scores_max[i_16] = tl::AllReduce::run(scores_max[i_16]); - } - #pragma unroll - for (int i_17 = 0; i_17 < 2; ++i_17) { - scores_scale[i_17] = exp2f(((scores_max_prev[i_17] * 1.803369e-01f) - (scores_max[i_17] * 1.803369e-01f))); - } - #pragma unroll - for (int i_18 = 0; i_18 < 64; ++i_18) { - acc_s[i_18] = exp2f(((acc_s[i_18] * 1.803369e-01f) - (scores_max[((i_18 & 3) >> 1)] * 1.803369e-01f))); - } - #pragma unroll - for (int i_19 = 0; i_19 < 2; ++i_19) { - scores_sum[i_19] = 0.000000e+00f; - #pragma unroll - for (int rv_3 = 0; 
rv_3 < 32; ++rv_3) { - scores_sum[i_19] = (scores_sum[i_19] + acc_s[((((rv_3 & 15) * 4) + (i_19 * 2)) + (rv_3 >> 4))]); - } - scores_sum[i_19] = tl::AllReduce::run(scores_sum[i_19]); - } - #pragma unroll - for (int i_20 = 0; i_20 < 2; ++i_20) { - logsum[i_20] = ((logsum[i_20] * scores_scale[i_20]) + scores_sum[i_20]); - } - cute::warpgroup_wait<0>(); - tl::mbarrier_arrive(_mbarrier[((k_1 & 1) + 6)]); - #pragma unroll - for (int i_21 = 0; i_21 < 32; ++i_21) { - acc_o[i_21] = (acc_o[i_21] * scores_scale[((i_21 & 3) >> 1)]); - } - #pragma unroll - for (int i_22 = 0; i_22 < 64; ++i_22) { - acc_s_cast[i_22] = ((half_t)acc_s[i_22]); - } - } - tl::mbarrier_wait(_mbarrier[3], 1); - tl::gemm_rs<192, 64, 128, 12, 1, 0, 0>((&(acc_s_cast[0])), (&(((half_t*)buf_dyn_shmem)[36864])), (&(acc_o[0]))); - tl::mbarrier_arrive(_mbarrier[7]); - #pragma unroll - for (int i_23 = 0; i_23 < 32; ++i_23) { - acc_o[i_23] = (acc_o[i_23] / logsum[((i_23 & 3) >> 1)]); - } - tl::syncthreads_partial(_mbarrier[9]); - #pragma unroll - for (int i_24 = 0; i_24 < 4; ++i_24) { - tl::ptx_stmatrix_x4((&(((half_t*)buf_dyn_shmem)[(((((((int)threadIdx.x) >> 5) * 1024) + ((((int)threadIdx.x) & 15) * 64)) + (i_24 * 16)) + (((((int)threadIdx.x) & 31) >> 4) * 8))])), __pack_half2(((half_t)acc_o[(i_24 * 8)]), ((half_t)acc_o[((i_24 * 8) + 1)])), __pack_half2(((half_t)acc_o[((i_24 * 8) + 2)]), ((half_t)acc_o[((i_24 * 8) + 3)])), __pack_half2(((half_t)acc_o[((i_24 * 8) + 4)]), ((half_t)acc_o[((i_24 * 8) + 5)])), __pack_half2(((half_t)acc_o[((i_24 * 8) + 6)]), ((half_t)acc_o[((i_24 * 8) + 7)]))); - } - tl::fence_proxy_async(); - tl::syncthreads_partial(_mbarrier[10]); - if (((int)threadIdx.x) == 0) { - tl::tma_store(Output_desc, (&(((half_t*)buf_dyn_shmem)[0])), 0, ((int)blockIdx.y), (((int)blockIdx.x) * 192), ((int)blockIdx.z)); - } - } -} - -template -static std::string ArrayToStr(const T* ptr, size_t n) { - std::stringstream ss; - ss << "["; - for (size_t i = 0; i < n; i++) { - if (i > 0) ss << ", "; - ss << ptr[i]; - } - ss << "]"; - return ss.str(); -} - -struct TensorMapArgs { - CUtensorMap* map; - CUtensorMapDataType type; - cuuint32_t tensorRank; - void* globalAddress; - cuuint64_t globalDim[5], globalStride[5]; - cuuint32_t boxDim[5], elementStrides[5]; - CUtensorMapInterleave interleave; - CUtensorMapSwizzle swizzle; - CUtensorMapL2promotion l2Promotion; - CUtensorMapFloatOOBfill oobFill; - - std::string ToDebugString() { - std::stringstream ss; - ss << "TMA Desc Addr: " << map << std::endl - << "format " << type << std::endl - << "dim " << tensorRank << std::endl - << "gmem_address " << globalAddress << std::endl - << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl - << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl - << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl - << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl - << "interleave " << interleave << std::endl - << "swizzle " << swizzle << std::endl - << "l2Promotion " << l2Promotion << std::endl - << "oobFill " << oobFill << std::endl; - return ss.str(); - } -}; - -void host_function(Flash_fwd_params params) { - int num_m_blocks = (params.seq_len + params.block_M - 1) / params.block_M; - dim3 grid(num_m_blocks, params.head, params.batch); - dim3 block(params.threads); - size_t sharedMemSize = (params.block_M + 4 * params.block_N) * params.dim * sizeof(half_t); - - CUtensorMap Q_desc = {0}; - CUtensorMap K_desc = {0}; - CUtensorMap V_desc = {0}; - CUtensorMap O_desc = {0}; - TensorMapArgs Q_arg; - 
TensorMapArgs K_arg; - TensorMapArgs V_arg; - TensorMapArgs O_arg; - - Q_arg.map = &Q_desc; - Q_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - Q_arg.tensorRank = 4; - Q_arg.globalAddress = params.q_ptr; - Q_arg.globalDim[0] = static_cast(params.dim); - Q_arg.globalDim[1] = static_cast(params.head); - Q_arg.globalDim[2] = static_cast(params.seq_len); - Q_arg.globalDim[3] = static_cast(params.batch); - Q_arg.globalStride[0] = static_cast(2); - Q_arg.globalStride[1] = static_cast(2 * params.dim); - Q_arg.globalStride[2] = static_cast(2 * params.dim * params.head); - Q_arg.globalStride[3] = static_cast(2 * params.dim * params.head * params.seq_len); - Q_arg.boxDim[0] = static_cast(64); - Q_arg.boxDim[1] = static_cast(1); - Q_arg.boxDim[2] = static_cast(params.block_M); - Q_arg.boxDim[3] = static_cast(1); - Q_arg.elementStrides[0] = static_cast(1); - Q_arg.elementStrides[1] = static_cast(1); - Q_arg.elementStrides[2] = static_cast(1); - Q_arg.elementStrides[3] = static_cast(1); - Q_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - Q_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - Q_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; - Q_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - - K_arg.map = &K_desc; - K_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - K_arg.tensorRank = 4; - K_arg.globalAddress = params.k_ptr; - K_arg.globalDim[0] = static_cast(params.dim); - K_arg.globalDim[1] = static_cast(params.head); - K_arg.globalDim[2] = static_cast(params.seq_len); - K_arg.globalDim[3] = static_cast(params.batch); - K_arg.globalStride[0] = static_cast(2); - K_arg.globalStride[1] = static_cast(2 * params.dim); - K_arg.globalStride[2] = static_cast(2 * params.dim * params.head); - K_arg.globalStride[3] = static_cast(2 * params.dim * params.head * params.seq_len); - K_arg.boxDim[0] = static_cast(64); - K_arg.boxDim[1] = static_cast(1); - K_arg.boxDim[2] = static_cast(params.block_N); - K_arg.boxDim[3] = static_cast(1); - K_arg.elementStrides[0] = static_cast(1); - K_arg.elementStrides[1] = static_cast(1); - K_arg.elementStrides[2] = static_cast(1); - K_arg.elementStrides[3] = static_cast(1); - K_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - K_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - K_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; - K_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - - V_arg.map = &V_desc; - V_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - V_arg.tensorRank = 4; - V_arg.globalAddress = params.v_ptr; - V_arg.globalDim[0] = static_cast(params.dim); - V_arg.globalDim[1] = static_cast(params.head); - V_arg.globalDim[2] = static_cast(params.seq_len); - V_arg.globalDim[3] = static_cast(params.batch); - V_arg.globalStride[0] = static_cast(2); - V_arg.globalStride[1] = static_cast(2 * params.dim); - V_arg.globalStride[2] = static_cast(2 * params.dim * params.head); - V_arg.globalStride[3] = static_cast(2 * params.dim * params.head * params.seq_len); - V_arg.boxDim[0] = static_cast(64); - V_arg.boxDim[1] = static_cast(1); - V_arg.boxDim[2] = static_cast(params.block_N); - V_arg.boxDim[3] = static_cast(1); - V_arg.elementStrides[0] = static_cast(1); - V_arg.elementStrides[1] = static_cast(1); - V_arg.elementStrides[2] = static_cast(1); - V_arg.elementStrides[3] = static_cast(1); - V_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - V_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - V_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; - V_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - - O_arg.map = &O_desc; - O_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - O_arg.tensorRank = 4; - 
O_arg.globalAddress = params.output_ptr; - O_arg.globalDim[0] = static_cast(params.dim); - O_arg.globalDim[1] = static_cast(params.head); - O_arg.globalDim[2] = static_cast(params.seq_len); - O_arg.globalDim[3] = static_cast(params.batch); - O_arg.globalStride[0] = static_cast(2); - O_arg.globalStride[1] = static_cast(2 * params.dim); - O_arg.globalStride[2] = static_cast(2 * params.dim * params.head); - O_arg.globalStride[3] = static_cast(2 * params.dim * params.head * params.seq_len); - O_arg.boxDim[0] = static_cast(64); - O_arg.boxDim[1] = static_cast(1); - O_arg.boxDim[2] = static_cast(params.block_M); - O_arg.boxDim[3] = static_cast(1); - O_arg.elementStrides[0] = static_cast(1); - O_arg.elementStrides[1] = static_cast(1); - O_arg.elementStrides[2] = static_cast(1); - O_arg.elementStrides[3] = static_cast(1); - O_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - O_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; - O_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; - O_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - - CUresult result; - result = cuTensorMapEncodeTiled( - Q_arg.map, Q_arg.type, Q_arg.tensorRank, Q_arg.globalAddress, Q_arg.globalDim, Q_arg.globalStride + 1, Q_arg.boxDim, - Q_arg.elementStrides, Q_arg.interleave, Q_arg.swizzle, Q_arg.l2Promotion, Q_arg.oobFill); - if (result != CUDA_SUCCESS) { - std::cout << "Failed to initialize the TMA descriptor " << result << std::endl - << Q_arg.ToDebugString(); - } - - result = cuTensorMapEncodeTiled( - K_arg.map, K_arg.type, K_arg.tensorRank, K_arg.globalAddress, K_arg.globalDim, K_arg.globalStride + 1, K_arg.boxDim, - K_arg.elementStrides, K_arg.interleave, K_arg.swizzle, K_arg.l2Promotion, K_arg.oobFill); - if (result != CUDA_SUCCESS) { - std::cout << "Failed to initialize the TMA descriptor " << result << std::endl - << K_arg.ToDebugString(); - } - - result = cuTensorMapEncodeTiled( - V_arg.map, V_arg.type, V_arg.tensorRank, V_arg.globalAddress, V_arg.globalDim, V_arg.globalStride + 1, V_arg.boxDim, - V_arg.elementStrides, V_arg.interleave, V_arg.swizzle, V_arg.l2Promotion, V_arg.oobFill); - if (result != CUDA_SUCCESS) { - std::cout << "Failed to initialize the TMA descriptor " << result << std::endl - << V_arg.ToDebugString(); - } - - result = cuTensorMapEncodeTiled( - O_arg.map, O_arg.type, O_arg.tensorRank, O_arg.globalAddress, O_arg.globalDim, O_arg.globalStride + 1, O_arg.boxDim, - O_arg.elementStrides, O_arg.interleave, O_arg.swizzle, O_arg.l2Promotion, O_arg.oobFill); - if (result != CUDA_SUCCESS) { - std::cout << "Failed to initialize the TMA descriptor " << result << std::endl - << O_arg.ToDebugString(); - } - - const int MAXBYTES = 1024 * 226; - cudaFuncSetAttribute(main_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, MAXBYTES); - - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; - return; - } - - main_kernel<<>>(K_desc, O_desc, Q_desc, V_desc); - - err = cudaGetLastError(); - if (err != cudaSuccess) { - std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(err) << std::endl; - return; - } - - err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; - return; - } -} \ No newline at end of file diff --git a/tl_verify/fa_kernel.hpp b/tl_verify/fa_kernel.hpp deleted file mode 100644 index 9b3103af1d52..000000000000 --- a/tl_verify/fa_kernel.hpp +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - 
-#include -#include - -struct Flash_fwd_params -{ - using index_t = int64_t; - // The QKV matrices. - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - void *__restrict__ output_ptr; - - index_t batch; - index_t seq_len; - index_t head; - index_t dim; - index_t block_M; - index_t block_N; - index_t threads; -}; - -void host_function(Flash_fwd_params params); -void host_function_no_tma(Flash_fwd_params params); - diff --git a/tl_verify/fa_no_tma.cu b/tl_verify/fa_no_tma.cu deleted file mode 100644 index 65ae4030f96f..000000000000 --- a/tl_verify/fa_no_tma.cu +++ /dev/null @@ -1,199 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "fa_kernel.hpp" - -template -__device__ void print_(T* reg, const char* name, int range, int total_threads) { - __syncthreads(); - if ((int)blockIdx.x == 0) { - if ((int)threadIdx.x == 0) { - printf("\n%s:\n", name); - } - for (int tid = 0; tid < total_threads; tid++) { - __syncthreads(); - if (threadIdx.x == tid) { - printf("tid: %d: ", tid); - for (int i = 0; i < range; i++) { - printf("%f ", float(reg[i])); - } - printf("\n"); - } - } - } - __syncthreads(); -} - - -extern "C" __global__ void __launch_bounds__(128) main_kernel_no_tma(__grid_constant__ const CUtensorMap K_desc, half_t* __restrict__ Output, __grid_constant__ const CUtensorMap Q_desc, __grid_constant__ const CUtensorMap V_desc) { -} - - -struct TensorMapArgs { - CUtensorMap* map; - CUtensorMapDataType type; - cuuint32_t tensorRank; - void* globalAddress; - cuuint64_t globalDim[5], globalStride[5]; - cuuint32_t boxDim[5], elementStrides[5]; - CUtensorMapInterleave interleave; - CUtensorMapSwizzle swizzle; - CUtensorMapL2promotion l2Promotion; - CUtensorMapFloatOOBfill oobFill; - - std::string ToDebugString() { - std::stringstream ss; - // ss << "TMA Desc Addr: " << map << std::endl - // << "format " << type << std::endl - // << "dim " << tensorRank << std::endl - // << "gmem_address " << globalAddress << std::endl - // << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl - // << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl - // << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl - // << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl - // << "interleave " << interleave << std::endl - // << "swizzle " << swizzle << std::endl - // << "l2Promotion " << l2Promotion << std::endl - // << "oobFill " << oobFill << std::endl; - return ss.str(); - } -}; - -void host_function_no_tma(Flash_fwd_params params) { - int num_m_blocks = (params.seq_len + params.block_M - 1) / params.block_M; - dim3 grid(num_m_blocks, params.head, params.batch); - dim3 block(128); - size_t sharedMemSize = (params.block_M + 2 * params.block_N) * params.dim * sizeof(half_t); // 24576; - - // int size = params.batch * params.head * params.seq_len * params.dim * sizeof(half_t); - - CUtensorMap Q_desc = {0}; - CUtensorMap K_desc = {0}; - CUtensorMap V_desc = {0}; - TensorMapArgs Q_arg; - TensorMapArgs K_arg; - TensorMapArgs V_arg; - - Q_arg.map = &Q_desc; - Q_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - Q_arg.tensorRank = 4; - Q_arg.globalAddress = params.q_ptr; - Q_arg.globalDim[0] = static_cast(params.dim); - Q_arg.globalDim[1] = static_cast(params.head); - Q_arg.globalDim[2] = static_cast(params.seq_len); - Q_arg.globalDim[3] = static_cast(params.batch); - Q_arg.globalStride[0] = static_cast(2); - Q_arg.globalStride[1] = static_cast(128); - Q_arg.globalStride[2] = static_cast(128); - Q_arg.globalStride[3] 
= static_cast(32768); - Q_arg.boxDim[0] = static_cast(64); - Q_arg.boxDim[1] = static_cast(1); - Q_arg.boxDim[2] = static_cast(64); - Q_arg.boxDim[3] = static_cast(1); - Q_arg.elementStrides[0] = static_cast(1); - Q_arg.elementStrides[1] = static_cast(1); - Q_arg.elementStrides[2] = static_cast(1); - Q_arg.elementStrides[3] = static_cast(1); - Q_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - Q_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - Q_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; - Q_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - - K_arg.map = &K_desc; - K_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - K_arg.tensorRank = 4; - K_arg.globalAddress = params.k_ptr; - K_arg.globalDim[0] = static_cast(64); - K_arg.globalDim[1] = static_cast(1); - K_arg.globalDim[2] = static_cast(256); - K_arg.globalDim[3] = static_cast(1); - K_arg.globalStride[0] = static_cast(2); - K_arg.globalStride[1] = static_cast(128); - K_arg.globalStride[2] = static_cast(128); - K_arg.globalStride[3] = static_cast(32768); - K_arg.boxDim[0] = static_cast(64); - K_arg.boxDim[1] = static_cast(1); - K_arg.boxDim[2] = static_cast(64); - K_arg.boxDim[3] = static_cast(1); - K_arg.elementStrides[0] = static_cast(1); - K_arg.elementStrides[1] = static_cast(1); - K_arg.elementStrides[2] = static_cast(1); - K_arg.elementStrides[3] = static_cast(1); - K_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - K_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - K_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; - K_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - - V_arg.map = &V_desc; - V_arg.type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - V_arg.tensorRank = 4; - V_arg.globalAddress = params.v_ptr; - V_arg.globalDim[0] = static_cast(64); - V_arg.globalDim[1] = static_cast(1); - V_arg.globalDim[2] = static_cast(256); - V_arg.globalDim[3] = static_cast(1); - V_arg.globalStride[0] = static_cast(2); - V_arg.globalStride[1] = static_cast(128); - V_arg.globalStride[2] = static_cast(128); - V_arg.globalStride[3] = static_cast(32768); - V_arg.boxDim[0] = static_cast(64); - V_arg.boxDim[1] = static_cast(1); - V_arg.boxDim[2] = static_cast(64); - V_arg.boxDim[3] = static_cast(1); - V_arg.elementStrides[0] = static_cast(1); - V_arg.elementStrides[1] = static_cast(1); - V_arg.elementStrides[2] = static_cast(1); - V_arg.elementStrides[3] = static_cast(1); - V_arg.interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - V_arg.swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - V_arg.l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; - V_arg.oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - - CUresult result; - result = cuTensorMapEncodeTiled( - Q_arg.map, Q_arg.type, Q_arg.tensorRank, Q_arg.globalAddress, Q_arg.globalDim, Q_arg.globalStride + 1, Q_arg.boxDim, - Q_arg.elementStrides, Q_arg.interleave, Q_arg.swizzle, Q_arg.l2Promotion, Q_arg.oobFill); - if (result != CUDA_SUCCESS) { - std::cout << "Failed to initialize the TMA descriptor " << result << std::endl - << Q_arg.ToDebugString(); - } - - result = cuTensorMapEncodeTiled( - K_arg.map, K_arg.type, K_arg.tensorRank, K_arg.globalAddress, K_arg.globalDim, K_arg.globalStride + 1, K_arg.boxDim, - K_arg.elementStrides, K_arg.interleave, K_arg.swizzle, K_arg.l2Promotion, K_arg.oobFill); - if (result != CUDA_SUCCESS) { - std::cout << "Failed to initialize the TMA descriptor " << result << std::endl - << K_arg.ToDebugString(); - } - - result = cuTensorMapEncodeTiled( - V_arg.map, V_arg.type, V_arg.tensorRank, V_arg.globalAddress, V_arg.globalDim, V_arg.globalStride + 1, V_arg.boxDim, - V_arg.elementStrides, 
V_arg.interleave, V_arg.swizzle, V_arg.l2Promotion, V_arg.oobFill); - if (result != CUDA_SUCCESS) { - std::cout << "Failed to initialize the TMA descriptor " << result << std::endl - << V_arg.ToDebugString(); - } - - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; - return; - } - - main_kernel_no_tma<<>>(K_desc, (half_t*)params.output_ptr, Q_desc, V_desc); - - err = cudaGetLastError(); - if (err != cudaSuccess) { - std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(err) << std::endl; - return; - } - - err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - std::cerr << "CUDA device synchronization failed: " << cudaGetErrorString(err) << std::endl; - return; - } -} \ No newline at end of file diff --git a/tl_verify/main.py b/tl_verify/main.py deleted file mode 100644 index 9fb8b9ef9e41..000000000000 --- a/tl_verify/main.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import fa_test -from flash_attn.flash_attn_interface import flash_attn_func -import random -import numpy as np - -def set_seed(seed): - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. - np.random.seed(seed) # Numpy module. - random.seed(seed) # Python random module. - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - -def ref_program(Q, K, V, casual): - # from flash_attn.flash_attn_interface import flash_attn_func - - # return flash_attn_func(Q, K, V, causal=casual) - assert casual == False, "casual is not supported" - batch, seq_len, heads, dim = Q.size() - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - block_M = seq_len - block_N = 64 if dim <= 128 else 32 - acc_s = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float) - acc_s_cast = torch.empty((batch, heads, block_M, block_N), device="cuda", dtype=torch.float16) - acc_o = torch.empty((batch, block_M, heads, dim), device="cuda", dtype=torch.float) - scores_max = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_max_prev = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_scale = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - scores_sum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - logsum = torch.empty((batch, heads, block_M), device="cuda", dtype=torch.float) - acc_o.fill_(0) - logsum.fill_(0) - scores_max.fill_(float('-inf')) - Q_scaled = Q * scale - - for i in range(int(seq_len / block_N)): - acc_s.fill_(0) - acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_scaled, K[:, i * block_N : (i + 1) * block_N, :, :]) # [batch, seqlen, heads, block_N] - # scores_max_prev = scores_max - # scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] - # scores_scale = torch.exp2(scores_max_prev - scores_max) - # acc_o *= scores_scale[:, :, :, None].transpose(1, 2) - acc_s = torch.exp2(acc_s - 32) - acc_s_cast = acc_s.to(torch.float16) - acc_o += torch.einsum('bhqk,bkhd->bqhd', acc_s_cast, V[:, i * block_N : (i + 1) * block_N, :, :]) - # scores_sum = acc_s.sum(dim=-1, keepdim=False) - # logsum = logsum * scores_scale + scores_sum - # acc_o /= logsum[:, :, :, None].transpose(1, 2) - return acc_o.to(torch.float16) - -set_seed(42) -causal = False -batch, seq_len, heads, dim = 64, 512, 16, 64 -shape = [batch, seq_len, heads, dim] -# q = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 
1.0) -# k = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) -q = torch.ones(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) -k = torch.ones(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) -v = torch.empty(*shape, device='cuda', dtype=torch.float16).normal_(-1.0, 1.0) - -output = fa_test.kernel_function(q, k, v, causal) -ref_output = flash_attn_func(q, k, v, causal=False) -# ref_output = ref_program(q, k, v, causal) -assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2) -print("Check: PASSED") - -warmups = 10 -runs = 10 -for _ in range(warmups): - out = fa_test.kernel_function(q, k, v, causal) - -start_event = torch.cuda.Event(enable_timing=True) -end_event = torch.cuda.Event(enable_timing=True) - -start_event.record() - -for _ in range(runs): - out = fa_test.kernel_function(q, k, v, causal) - -end_event.record() -torch.cuda.synchronize() - -latency = start_event.elapsed_time(end_event) - -flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim -total_flops = 2 * flops_per_matmul -print(f"total_flops: {total_flops}") -print(f"TFLOPS: {total_flops / latency * runs * 1e-9}") -print(f"Latency: {latency / runs:.2f} ms") \ No newline at end of file diff --git a/tl_verify/setup.py b/tl_verify/setup.py deleted file mode 100644 index f2fd38dcbc64..000000000000 --- a/tl_verify/setup.py +++ /dev/null @@ -1,110 +0,0 @@ -from setuptools import setup -import torch.utils.cpp_extension -import subprocess -from packaging.version import parse, Version -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -torch.utils.cpp_extension.CUDAExtension.debug = True - - -def append_nvcc_threads(nvcc_extra_args): - return nvcc_extra_args + ["--threads", "4"] - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("12.3"): - raise RuntimeError("FA Hopper is only supported on CUDA 12.3 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_90a,code=sm_90a") - -nvcc_flags = [ - "-O3", - # "-O0", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - # "--ptxas-options=-v", # printing out number of registers - "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers - "-lineinfo", - "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging - "-DNDEBUG", # Important, otherwise performance is severely impacted - "-DQBLKSIZE=128", - "-DKBLKSIZE=128", - "-DCTA256", - "-DDQINRMEM", - # "-keep" - ] - -# extra_compile_args = { -# 'cxx': ['-O3', '-std=c++17'], -# 'nvcc': [ -# # '-arch=sm_90a', -# '-gencode arch=compute_90a,code=compute_90a', -# '--use_fast_math', -# '-std=c++17', -# "-O3", -# "-U__CUDA_NO_HALF_OPERATORS__", -# "-U__CUDA_NO_HALF_CONVERSIONS__", -# "-U__CUDA_NO_BFLOAT16_OPERATORS__", -# "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", -# "-U__CUDA_NO_BFLOAT162_OPERATORS__", -# 
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__", -# "--expt-relaxed-constexpr", -# "--expt-extended-lambda", -# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers -# '-I/usr/local/cuda/include', -# '-I/home/msra/cy/tvm.tl/src/tl', -# '-I/home/msra/cy/tvm.tl/cutlass/include', -# '-lcuda', -# '-lineinfo', -# "-lnvToolsExt", -# "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging -# "-DNDEBUG", # Important, otherwise performance is severely impacted -# # '-keep' # Uncomment this line to keep the generated .ptx file -# ], -# } - -extra_compile_args = { - "cxx": ["-O3", "-std=c++17"], - "nvcc": append_nvcc_threads( - nvcc_flags + ["-DEXECMODE=0"] + cc_flag - ), -} - -include_dirs = [ - '/home/msra/cy/tvm.tl/src/tl', - '/home/msra/cy/tvm.tl/cutlass/include', - '/usr/local/cuda/include' -] - -setup( - name='fa_test', - ext_modules=[ - CUDAExtension( - 'fa_test', - sources=['cuda_interface.cpp', 'fa_kernel.cu', 'fa_no_tma.cu'], - extra_compile_args=extra_compile_args, - include_dirs=include_dirs, - libraries=["cuda"] - ), - ], - cmdclass={ - 'build_ext': BuildExtension - } -) - -# sudo -E env PATH=$PATH PYTHONPATH=$PYTHONPATH TMPDIR=~/cy/ncu_tmp ncu --set full -k regex:"main_kernel" --launch-count 1 --launch-skip 10 --target-processes application-only --cache-control none --clock-control none --apply-rules yes --import-source yes --check-exit-code yes -f -o reports/tl_8_2048_8_256_false /home/msra/miniconda3/envs/tl/bin/python main.py \ No newline at end of file From c59e6dd74418746dd9f3f97e22dcc979df35748a Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Wed, 2 Oct 2024 15:45:59 +0000 Subject: [PATCH 22/23] [tl] Update --- python/tvm/tl/engine.py | 8 +- python/tvm/tl/utils.py | 6 + src/tl/transform/warp_specialized_rewriter.cc | 221 ++++++++++++++++-- 3 files changed, 212 insertions(+), 23 deletions(-) diff --git a/python/tvm/tl/engine.py b/python/tvm/tl/engine.py index 3e2eaacef4de..b99ec45b7adc 100644 --- a/python/tvm/tl/engine.py +++ b/python/tvm/tl/engine.py @@ -21,6 +21,7 @@ import tvm from tvm import tir, tl, relay from tvm.contrib import nvcc +from tvm.tl.code_replace import replace_code def is_device_call(func: tir.PrimFunc): @@ -55,10 +56,12 @@ def tvm_callback_cuda_compile(code, target): arch, options=[ "-std=c++17", + "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers "--use_fast_math", "-I" + tl_template_path, "-I" + cutlass_path, ], + get_output=True, ) # with open("save.ptx", "wb") as f: # f.write(ptx) @@ -70,11 +73,6 @@ def extrac_params(func: tir.PrimFunc): tensor_types = [relay.TensorType(buffer.shape, buffer.dtype) for buffer in buffers] return tensor_types -@tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) -def tvm_callback_cuda_postproc(code, _): - code = code.replace("""original code""", -"""modified code""") - return code def lower(func): params = extrac_params(func) diff --git a/python/tvm/tl/utils.py b/python/tvm/tl/utils.py index 90c3e95ad61f..c12adf173c73 100644 --- a/python/tvm/tl/utils.py +++ b/python/tvm/tl/utils.py @@ -135,6 +135,12 @@ def assert_allclose(self, reference_program: callable, atol: float = 1e-8, rtol: assert len(lib_outs) == len(ref_outs) # torch.set_printoptions(edgeitems=torch.inf) for lhs, rhs in zip(lib_outs, ref_outs): + # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) + # total_elements = lhs.numel() + # num_not_close = (~close_mask).sum().item() + # percentage_not_close = 
(num_not_close / total_elements) * 100 + # print(f"{percentage_not_close:.2f}% of the elements are not close.") + # print(f"Total elements: {total_elements}, Not close elements: {num_not_close}") assert torch.allclose(lhs, rhs, rtol=rtol, atol=atol), (lhs, rhs) def assert_consistent(self, repeat=10): diff --git a/src/tl/transform/warp_specialized_rewriter.cc b/src/tl/transform/warp_specialized_rewriter.cc index 37b1d6a9b5b2..c1408db5cc51 100644 --- a/src/tl/transform/warp_specialized_rewriter.cc +++ b/src/tl/transform/warp_specialized_rewriter.cc @@ -160,6 +160,22 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { return Evaluate(call); } +static bool isGemm(Stmt stmt) { + bool is_gemm = false; + if (stmt.as()) { + auto call = Downcast(stmt)->value.as(); + if (call && call->op.same_as(Op::Get("tir.call_extern"))) { + if (call->args[0].as()) { + std::string name = Downcast(call->args[0])->value; + if (name.find("gemm") != std::string::npos) { + is_gemm = true; + } + } + } + } + return is_gemm; +} + class ProducerTraitsCollector : public StmtExprVisitor { public: ProducerTraitsCollector() { Clear(); } @@ -256,9 +272,107 @@ Block MakeGroupBlock(const Stmt& stmt, const Map& annotations return block; } +struct OpInfo { + int group_size, order, stage; + std::vector group; +}; +struct PipelineInfo { + std::vector op_infos; + + PipelineInfo() = default; + PipelineInfo( + Array> group_info, + Array order_info, + Array stage_info + ) { + int n = static_cast(group_info.size()); + ICHECK(n == static_cast(order_info.size())); + ICHECK(n == static_cast(stage_info.size())); + int cur_id = 0; + for (int i = 0; i < n; i++) { + OpInfo op_info; + op_info.group_size = group_info[i].size(); + for (int j = 0; j < op_info.group_size; j++) { + op_info.group.push_back(group_info[i][j].as()->value); + } + op_info.order = order_info[i].as()->value; + op_info.stage = stage_info[i].as()->value; + op_infos.push_back(op_info); + } + } + + PipelineInfo(const PipelineInfo& other) { + for (auto op_info : other.op_infos) { + op_infos.push_back(op_info); + } + } + + std::pair FindStmt(int stmt_idx) { + for (size_t i = 0; i < op_infos.size(); i++) { + for (size_t j = 0; j < op_infos[i].group.size(); j++) { + if (op_infos[i].group[j] == stmt_idx) { + return std::make_pair(i, j); + } + } + } + return std::make_pair(-1, -1); + } + + void UpdateOrder(int order) { + for (int i = 0; i < static_cast(op_infos.size()); i++) { + if (op_infos[i].order >= order && op_infos[i].order > 0) { + op_infos[i].order++; + } + } + } + + int SplitOp(int stmt_idx) { + auto pair = FindStmt(stmt_idx); + int op_idx = pair.first; + int inner_idx = pair.second; + ICHECK(op_idx != -1); + ICHECK(inner_idx != -1); + OpInfo half0; + OpInfo half1; + // The order to do sync + int sync_order = op_infos[op_idx].order + 1; + UpdateOrder(sync_order); + + half0.group_size = inner_idx + 1; + half0.order = op_infos[op_idx].order; + half0.stage = op_infos[op_idx].stage; + for (int i = 0; i <= inner_idx; i++) { + half0.group.push_back(op_infos[op_idx].group[i]); + } + half1.group_size = op_infos[op_idx].group_size - inner_idx - 1; + half1.order = op_infos[op_idx].order + 2; + half1.stage = op_infos[op_idx].stage; + for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) { + half1.group.push_back(op_infos[op_idx].group[i]); + } + op_infos.erase(op_infos.begin() + op_idx); + if (half0.group_size > 0) { + op_infos.insert(op_infos.begin() + op_idx, half0); + } + if (half1.group_size > 0) { + UpdateOrder(half1.order); + 
op_infos.insert(op_infos.begin() + op_idx + 1, half1); + } + return sync_order; + } + + void PrintPipelineInfo() { + std::cout << "Print op_infos:" << std::endl; + for (size_t i = 0; i < op_infos.size(); i++) { + std::cout << i << " " << op_infos[i].group_size << " " << op_infos[i].order << " " << op_infos[i].stage << std::endl; + } + std::cout << "End of print" << std::endl; + } +}; + class GroupOpRewriter : public StmtExprMutator { public: - GroupOpRewriter(Array>& group_info) : group_info_(group_info) {} + GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {} private: Stmt VisitStmt_(const ForNode* op) final { @@ -269,25 +383,37 @@ class GroupOpRewriter : public StmtExprMutator { return GetRef(op); } Array new_body; - for (size_t i = 0; i < group_info_.size(); i++) { - if (group_info_[i].size() == 0) continue; + int cur_id = 0; + for (int i = 0; i < static_cast(pipeline_info_.op_infos.size()); i++) { + if (pipeline_info_.op_infos[i].group_size == 0) continue; Array block_stmt; - for (size_t j = 0; j < group_info_[i].size(); j++) { - ICHECK(group_info_[i][j].as()); - int index = static_cast(group_info_[i][j].as()->value); - ICHECK(original_node->seq[index].as()); - auto block = original_node->seq[index].as(); + for (int j = 0; j < static_cast(pipeline_info_.op_infos[i].group_size); j++) { + // ICHECK(group_info_[i][j].as()); + // int index = static_cast(group_info_[i][j].as()->value); + ICHECK(original_node->seq[cur_id].as()); + auto block = original_node->seq[cur_id].as(); // TODO: handle nested seqstmt block_stmt.push_back(block->body); + cur_id++; } new_body.push_back( MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); } - For new_for = For(op->loop_var, op->min, op->extent, op->kind, new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)), op->thread_binding, op->annotations); + Array order_anno; + Array stage_anno; + for (auto op_info : pipeline_info_.op_infos) { + order_anno.push_back(Integer(op_info.order)); + stage_anno.push_back(Integer(op_info.stage)); + } + Map for_annotations = op->annotations; + for_annotations.erase("software_pipeline_group"); + for_annotations.Set("software_pipeline_order", order_anno); + for_annotations.Set("software_pipeline_stage", stage_anno); + For new_for = For(op->loop_var, op->min, op->extent, op->kind, new_body.size() == 1 ? 
new_body[0] : SeqStmt(std::move(new_body)), op->thread_binding, for_annotations); return new_for; } - Array> group_info_; + PipelineInfo pipeline_info_; }; class WSCodeEmitter : public StmtMutator { public: @@ -325,6 +451,16 @@ class WSCodeEmitter : public StmtMutator { auto seq_transformed = op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); }); auto map = ExtractSyncPattern(op->seq); + // std::cout << "Print ExtractSyncPattern" << std::endl; + // for (int i = 0; i < static_cast(op->seq.size()); i++) { + // std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " " << map.release_after[i] << std::endl; + // } + // std::cout << "Print sync pattern" << std::endl; + // for (auto pattern : map.patterns) { + // std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl; + // } + // std::cout << "End of ExtractSyncPattern" << std::endl; + // pipeline_info_.PrintPipelineInfo(); Array new_body; Map annotations; annotations.Set(String("stmt_group"), Integer(1)); @@ -378,15 +514,37 @@ class WSCodeEmitter : public StmtMutator { block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); } block_stmt.push_back(seq_transformed[i]); + // new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); if (map.release_after[i]) { PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; block_stmt.push_back(makeArriveBarrier(release_barrier_id)); for (int j = 0; j < num_stages_; j++) { released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); } + // Update the pipeline info + // Todo: handle sync } new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); } + // Filter out the producer stmts + int cur_id = 0; + PipelineInfo new_pipeline_info; + for (int i = 0; i < static_cast(pipeline_info_.op_infos.size()); i++) { + auto op_info = pipeline_info_.op_infos[i]; + bool is_producer = false; + for (int j = 0; j < op_info.group_size; j++) { + if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) { + is_producer = true; + } + cur_id++; + } + if (is_producer) { + ICHECK(op_info.group_size == 1); + } else { + new_pipeline_info.op_infos.push_back(op_info); + } + } + pipeline_info_ = new_pipeline_info; } num_barriers_ += map.patterns.size() * num_stages_; @@ -404,39 +562,65 @@ class WSCodeEmitter : public StmtMutator { ICHECK(num_stages_ == 1) << "Nested pipeline not supported."; } + Array> group_info_array; + Array order_info_array; + Array stage_info_array; + + auto group_anno = op->annotations.Get("software_pipeline_group"); + if (group_anno.defined()) { + group_info_array = Downcast>>(group_anno); + } + auto order_anno = op->annotations.Get("software_pipeline_order"); + if (order_anno.defined()) { + order_info_array = Downcast>(order_anno); + } + auto stage_anno = op->annotations.Get("software_pipeline_stage"); + if (stage_anno.defined()) { + stage_info_array = Downcast>(stage_anno); + } + + PipelineInfo pipeline_info(group_info_array, order_info_array, stage_info_array); + if (pipeline_info.op_infos.size() > 0) { + ICHECK(pipeline_info_.op_infos.size() == 0) << "Nested pipeline not supported."; + } + PrimExpr parity_before = std::move(parity_); PrimExpr stage_before = std::move(stage_); int num_stages_before = num_stages_; + PipelineInfo pipeline_info_before = pipeline_info_; num_stages_ = num_stages; + pipeline_info_ = pipeline_info; stage_ = FloorMod(op->loop_var - op->min, num_stages); parity_ = 
FloorMod(parity_before * op->extent + FloorDiv(op->loop_var - op->min, num_stages), 2); auto result = FilterByRole(op); + Stmt grouped_for_node; + if (result.as() && group_anno.defined() && group_info_array.size() > 0 && !is_emitting_producer_) { + GroupOpRewriter group_op_rewriter(pipeline_info_); + auto for_node = Downcast(result); + grouped_for_node = group_op_rewriter(for_node); + } + parity_ = std::move(parity_before); stage_ = std::move(stage_before); num_stages_ = num_stages_before; + pipeline_info_ = pipeline_info_before; // remove pipeline annotation auto for_node = result.as(); if (result.as()) { auto for_node = Downcast(result); for_node.CopyOnWrite()->annotations.erase("num_stages"); - if (is_emitting_producer_) { + if (is_emitting_producer_ || group_info_array.size() == 0) { for_node.CopyOnWrite()->annotations.erase("software_pipeline_order"); for_node.CopyOnWrite()->annotations.erase("software_pipeline_stage"); } - auto group_info_anno = op->annotations.Get("software_pipeline_group"); - if (is_emitting_producer_ || !group_info_anno.defined()) { + if (is_emitting_producer_ || !group_anno.defined() ||group_info_array.size() == 0) { return for_node; } - auto group_info = - Downcast>>(op->annotations.at("software_pipeline_group")); - GroupOpRewriter group_op_rewriter(group_info); - for_node.CopyOnWrite()->annotations.erase("software_pipeline_group"); - Stmt grouped_for_node = group_op_rewriter(for_node); return grouped_for_node; } return result; @@ -622,6 +806,7 @@ class WSCodeEmitter : public StmtMutator { PrimExpr stage_ = 0; int num_stages_ = 1; Var thread_var_; + PipelineInfo pipeline_info_; friend class WarpSpecializedRewriter; }; From 536d1e8db02dd1a881474f93e218edb088659f30 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Fri, 4 Oct 2024 03:17:09 +0000 Subject: [PATCH 23/23] [tl] Update for merge. Fix bug in thread_partial_sync. --- .gitignore | 8 +------- python/tvm/tl/language.py | 16 +++++++++++---- src/tir/transforms/thread_partial_sync.cc | 24 +++++++++++------------ 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index ca39f0a3a5f5..6b18c12e02a2 100644 --- a/.gitignore +++ b/.gitignore @@ -278,11 +278,5 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py # GDB history file .gdb_history -*/reports/* -play.py *.ptx -*.ncu-rep -tl_verify/* -*/modified_code.cu -modified_code.cu -code_replace.py +*.ncu-rep \ No newline at end of file diff --git a/python/tvm/tl/language.py b/python/tvm/tl/language.py index d9e4a974ff05..6e3e78686cff 100644 --- a/python/tvm/tl/language.py +++ b/python/tvm/tl/language.py @@ -47,10 +47,10 @@ def Pipelined( start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int = 0, - order: List[int] = [], - stage: List[int] = [], - sync: List[List[int]] = [], - group: List[List[int]] = [] + order: List[int] = None, + stage: List[int] = None, + sync: List[List[int]] = None, + group: List[List[int]] = None ): """Tools to construct pipelined for loop. 
@@ -74,6 +74,14 @@ def Pipelined( start = IntImm(start.dtype, 0) else: start = 0 + if order is None: + order = [] + if stage is None: + stage = [] + if sync is None: + sync = [] + if group is None: + group = [] # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group) diff --git a/src/tir/transforms/thread_partial_sync.cc b/src/tir/transforms/thread_partial_sync.cc index a4c2fdb22a31..acc052b4c3fb 100644 --- a/src/tir/transforms/thread_partial_sync.cc +++ b/src/tir/transforms/thread_partial_sync.cc @@ -55,18 +55,18 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor { // Redirect all "shared.dyn" buffer access to the same buffer var // so that the accesses can be planned together. Var shared_dyn_buf; - // for (StmtEntry& entry : seq) { - // for (AccessEntry& access : entry.access) { - // if (access.scope.rank == StorageRank::kShared && access.scope.tag == ".dyn" && - // access.buffer.defined()) { - // if (!shared_dyn_buf.defined()) { - // shared_dyn_buf = access.buffer; - // } else { - // access.buffer = shared_dyn_buf; - // } - // } - // } - // } + for (StmtEntry& entry : seq) { + for (AccessEntry& access : entry.access) { + if (access.scope.rank == StorageRank::kShared && access.scope.tag == ".dyn" && + access.buffer.defined()) { + if (!shared_dyn_buf.defined()) { + shared_dyn_buf = access.buffer; + } else { + access.buffer = shared_dyn_buf; + } + } + } + } // Unsynced reads and writes std::vector reads;
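Note on the Pipelined signature change in the python/tvm/tl/language.py hunk above: replacing the mutable list defaults (order: List[int] = [], stage, sync, group) with None sentinels avoids Python's shared-default-argument pitfall, where the default list is built once at function-definition time and then reused and potentially mutated across calls. The sketch below is illustrative commentary only, not part of any patch; the function names are hypothetical and not part of the TL API.

from typing import List, Optional

# Buggy pattern: the default list is evaluated once, when the function is
# defined, so every call that omits `order` shares (and mutates) one object.
def make_schedule_buggy(order: List[int] = []) -> List[int]:
    order.append(len(order))
    return order

# Fixed pattern (the same sentinel idiom the patch applies to Pipelined):
# default to None and allocate a fresh list inside the body on each call.
def make_schedule_fixed(order: Optional[List[int]] = None) -> List[int]:
    if order is None:
        order = []
    order.append(len(order))
    return order

print(make_schedule_buggy())  # [0]
print(make_schedule_buggy())  # [0, 1]  <- state leaked from the previous call
print(make_schedule_fixed())  # [0]
print(make_schedule_fixed())  # [0]     <- independent list on every call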