diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 9254f31c6e1b..c3891c34aa02 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -630,6 +630,7 @@ TVM_DLL Pass LowerAutoCopy(); */ TVM_DLL Pass RenormalizeSplitPattern(); +TVM_DLL Pass InjectAsyncCopy(); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index b6dbd9c0152c..3e32c7198307 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -592,6 +592,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn")); mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); + mixed_pass_list.push_back(tir::transform::InjectAsyncCopy()); mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index f3a9f990599f..f411c1c00f37 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -119,13 +119,13 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } // Plan the sync - std::vector Summarize(std::vector seq, const ForNode* loop) final { - return PlanSync(seq, loop, false); + std::vector Summarize(std::vector seq, const StmtNode* parent) final { + return PlanSync(seq, parent, false); } private: // Plan write synchronization if write is not coherent - std::vector PlanSync(std::vector seq, const ForNode* loop, + std::vector PlanSync(std::vector seq, const StmtNode* parent, bool force_sync_at_end) { // detect write barriers // access by the co-processor. @@ -167,7 +167,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } } bool sync_at_end = force_sync_at_end; - if (loop != nullptr && !sync_at_end) { + if (parent->IsInstance() && !sync_at_end) { // loop carray dependency for (size_t i = 0; i < seq.size(); ++i) { const StmtEntry& s = seq[i]; @@ -239,17 +239,17 @@ class CoProcBarrierDetector : public StorageAccessVisitor { } // Plan the sync - std::vector Summarize(std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const StmtNode* parent) final { if (read_barrier_) { - return PlanReadBarrier(seq, loop); + return PlanReadBarrier(seq, parent); } else { - return PlanWriteBarrier(seq, loop); + return PlanWriteBarrier(seq, parent); } } private: // Plan write barrier at Read after write point. - std::vector PlanWriteBarrier(std::vector seq, const ForNode* loop) { + std::vector PlanWriteBarrier(std::vector seq, const StmtNode* parent) { std::vector read_seq; std::unordered_map > write_set; @@ -276,7 +276,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { } } // loop carry - if (loop != nullptr) { + if (parent->IsInstance()) { for (const AccessEntry& acc : read_seq) { fupdate(seq.size(), acc); } @@ -287,7 +287,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { return read_seq; } - std::vector PlanReadBarrier(std::vector seq, const ForNode* loop) { + std::vector PlanReadBarrier(std::vector seq, const StmtNode* parent) { std::vector write_seq; std::unordered_map > read_set; @@ -315,7 +315,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { } } // loop carry - if (loop != nullptr) { + if (parent->IsInstance()) { for (const AccessEntry& acc : write_seq) { fupdate(0, acc); } diff --git a/src/tir/transforms/inject_async_copy.cc b/src/tir/transforms/inject_async_copy.cc new file mode 100644 index 000000000000..5bdaecfd6d2b --- /dev/null +++ b/src/tir/transforms/inject_async_copy.cc @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Replace copy from global to shared with async copy + * \file inject_async_copy.cc + */ +#include +#include +#include +#include +#include + +#include "storage_access.h" + +namespace tvm { +namespace tir { + +class AsyncCopyOnjector : public StmtMutator { + public: + std::unordered_set pending_async_copy_; + + Stmt VisitStmt_(const BufferStoreNode* store) { + if (store->buffer.scope() == "shared") { + if (auto* load = store->value.as()) { + if (load->buffer.scope() == "global") { + ICHECK(load->indices.size() == 1); + const int bytes = load->indices[0]->dtype.lanes() * load->buffer->dtype.bytes(); + if (bytes == 4 || bytes == 8 || bytes == 16) { + auto dst_offset = store->indices[0].as()->base; + auto src_offset = load->indices[0].as()->base; + pending_async_copy_.insert(store->buffer->data.get()); + return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, dst_offset, load->buffer->data, src_offset, + PrimExpr(bytes)})); + } + } + } + } + return StmtMutator::VisitStmt_(store); + } + + Stmt VisitStmt_(const SeqStmtNode* store) { return StmtMutator::VisitStmt_(store); } +}; + +class InsertWaitGroupPlanner : public StorageAccessVisitor { + public: + explicit InsertWaitGroupPlanner(const std::unordered_set& pending_async_copy) + : pending_async_copy_{pending_async_copy} {} + + std::unordered_set insert_wait_before_; + + protected: + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { + return scope == StorageScope::Create("shared") && pending_async_copy_.count(buf) != 0; + } + + std::vector Summarize(std::vector seq, const StmtNode* parent) final { + std::vector flattened; + std::vector pending_writes; + + for (const StmtEntry& s : seq) { + bool wait_before_stmt = false; + for (const AccessEntry& acc : s.access) { + ICHECK(pending_async_copy_.count(acc.buffer.get()) != 0); + if (acc.type == kRead) { + if (FindConflict(pending_writes, acc)) { + wait_before_stmt = true; + break; + } + } else if (acc.type == kWrite) { + pending_writes.push_back(acc); + } + flattened.push_back(acc); + } + if (wait_before_stmt) { + ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; + insert_wait_before_.insert(s.stmt); + } + } + return flattened; + } + + private: + bool FindConflict(const std::vector& pending_writes, const AccessEntry& read) { + for (const AccessEntry& pending_write : pending_writes) { + if (pending_write.buffer == read.buffer) { + return true; + } + } + return false; + } + + const std::unordered_set pending_async_copy_; +}; + +class InsertWaitGroup : public StmtMutator { + public: + explicit InsertWaitGroup(std::unordered_set insert_wait_before) + : insert_wait_before_(insert_wait_before) {} + + Stmt VisitStmt(const Stmt& stmt) final { + if (insert_wait_before_.count(stmt.get())) { + auto commit_group = + Evaluate(Call(DataType::Void(), tvm::tir::builtin::ptx_commit_group(), {})); + auto wait_group = + Evaluate(Call(DataType::Void(), tvm::tir::builtin::ptx_wait_group(), {PrimExpr(0)})); + auto ret = StmtMutator::VisitStmt(stmt); + return SeqStmt({commit_group, wait_group, ret}); + } else { + return StmtMutator::VisitStmt(stmt); + } + } + + std::unordered_set insert_wait_before_; +}; + +namespace transform { + +Pass InjectAsyncCopy() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + AsyncCopyOnjector copy_inject; + auto cp_async_injected = copy_inject(n->body); + n->body = cp_async_injected; + LOG(INFO) << f; + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectAsyncCopy", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectAsyncCopy").set_body_typed(InjectAsyncCopy); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 4f19f708880c..862a7e046d85 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -110,7 +110,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { StmtExprVisitor::VisitStmt_(op); StmtEntry s; s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); + s.access = Summarize(std::move(scope_.back()), op); scope_.pop_back(); if (!s.access.empty()) { for (AccessEntry& e : s.access) { @@ -134,7 +134,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { scope_.push_back(std::vector()); StmtExprVisitor::VisitStmt_(op); // no need to take the result as the thread barrier automatically syncs. - Summarize(std::move(scope_.back()), nullptr); + Summarize(std::move(scope_.back()), op); in_device_env_ = false; scope_.pop_back(); } else { @@ -185,12 +185,12 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { this->VisitStmt(op->then_case); StmtEntry s; s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); + s.access = Summarize(std::move(scope_.back()), op); scope_.pop_back(); if (op->else_case.defined()) { scope_.push_back(std::vector()); this->VisitStmt(op->else_case); - auto v = Summarize(std::move(scope_.back()), nullptr); + auto v = Summarize(std::move(scope_.back()), op); scope_.pop_back(); s.access.insert(s.access.end(), v.begin(), v.end()); } @@ -205,7 +205,7 @@ void StorageAccessVisitor::VisitStmt_(const WhileNode* op) { this->VisitStmt(op->body); StmtEntry s; s.stmt = op; - s.access = Summarize(std::move(scope_.back()), nullptr); + s.access = Summarize(std::move(scope_.back()), op); scope_.pop_back(); scope_.back().emplace_back(std::move(s)); --condition_counter_; @@ -253,6 +253,25 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { e.scope = StorageScope::Create(s); curr_stmt_.access.emplace_back(std::move(e)); } + } else if (op->op.same_as(builtin::ptx_cp_async())) { + ICHECK_EQ(curr_stmt_.access.size(), 0U); + curr_stmt_.stmt = op; + + Var buf = Downcast(op->args[0]); + StorageScope scope = GetScope(buf); + AccessEntry e; + e.threads = env_threads(); + e.buffer = buf; + e.dtype = op->dtype; + e.touched.push_back(arith::IntSet::Vector(op->args[1])); + e.type = kWrite; + e.scope = scope; + curr_stmt_.access.emplace_back(std::move(e)); + + // push to the scope + scope_.back().push_back(curr_stmt_); + // clear access entry. + curr_stmt_.access.clear(); } else { StmtExprVisitor::VisitExpr_(op); } diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index a48ee73f17fc..7b6cad72ad9c 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * \return The summarized sequence that represent access that * the parent should taken care of to synchronize. */ - virtual std::vector Summarize(std::vector seq, const ForNode* loop) = 0; + virtual std::vector Summarize(std::vector seq, const StmtNode* loop) = 0; /*! * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index ce3f8fd3e3ac..fa2674e9860d 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -49,7 +49,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { return in_device_env() && scope == sync_scope_; } // Plan the sync - std::vector Summarize(std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const StmtNode* parent) final { // Unsynced reads and writes std::vector reads; std::vector writes; @@ -101,7 +101,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { syncs_inserted_.insert(s.stmt); } } - if (loop != nullptr) { + if (parent->IsInstance()) { for (size_t i = 0; i < seq.size(); ++i) { const StmtEntry& s = seq[i]; if (syncs_inserted_.count(s.stmt) != 0) break; @@ -166,7 +166,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } } head.insert(head.end(), tail.begin(), tail.end()); - if (loop != nullptr) { + if (parent->IsInstance()) { // clear double buffer flag after a loop is finished. for (AccessEntry& e : head) { e.double_buffer_write = false;