[TIR] Add pass to replace global to shared memory load with cp async
masahi committed May 25, 2022
1 parent 82ad1d4 commit 9c1126d
Showing 7 changed files with 194 additions and 20 deletions.
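In short, the new pass scans device code for vectorized global-to-shared BufferStore/BufferLoad pairs and rewrites them into ptx_cp_async intrinsic calls. A minimal sketch of driving it from C++, assuming the standard TVM pass plumbing (ApplyAsyncCopy is a hypothetical helper; in the real pipeline the pass runs inside MixedModulePassManager, see driver_api.cc below):

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/transform.h>

// Sketch: apply the new pass to an already-lowered IRModule.
// ThreadSync runs first, mirroring the pipeline order in driver_api.cc.
tvm::IRModule ApplyAsyncCopy(tvm::IRModule mod) {
  tvm::transform::Sequential seq({tvm::tir::transform::ThreadSync("shared"),
                                  tvm::tir::transform::InjectAsyncCopy()});
  return seq(mod);
}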
1 change: 1 addition & 0 deletions include/tvm/tir/transform.h
@@ -630,6 +630,7 @@ TVM_DLL Pass LowerAutoCopy();
*/
TVM_DLL Pass RenormalizeSplitPattern();

+ /*!
+  * \brief Rewrite eligible global to shared memory copies into asynchronous copies (cp.async).
+  */
+ TVM_DLL Pass InjectAsyncCopy();
} // namespace transform
} // namespace tir
} // namespace tvm
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
@@ -592,6 +592,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
+ mixed_pass_list.push_back(tir::transform::InjectAsyncCopy());
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

22 changes: 11 additions & 11 deletions src/tir/transforms/coproc_sync.cc
@@ -119,13 +119,13 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
}

// Plan the sync
- std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
-   return PlanSync(seq, loop, false);
+ std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const StmtNode* parent) final {
+   return PlanSync(seq, parent, false);
}

private:
// Plan write synchronization if write is not coherent
- std::vector<AccessEntry> PlanSync(std::vector<StmtEntry> seq, const ForNode* loop,
+ std::vector<AccessEntry> PlanSync(std::vector<StmtEntry> seq, const StmtNode* parent,
bool force_sync_at_end) {
// detect write barriers
// access by the co-processor.
@@ -167,7 +167,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
}
}
bool sync_at_end = force_sync_at_end;
- if (loop != nullptr && !sync_at_end) {
+ if (parent->IsInstance<ForNode>() && !sync_at_end) {
// loop carry dependency
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
@@ -239,17 +239,17 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
}

// Plan the sync
- std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
+ std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const StmtNode* parent) final {
if (read_barrier_) {
-   return PlanReadBarrier(seq, loop);
+   return PlanReadBarrier(seq, parent);
} else {
-   return PlanWriteBarrier(seq, loop);
+   return PlanWriteBarrier(seq, parent);
}
}

private:
// Plan write barrier at Read after write point.
- std::vector<AccessEntry> PlanWriteBarrier(std::vector<StmtEntry> seq, const ForNode* loop) {
+ std::vector<AccessEntry> PlanWriteBarrier(std::vector<StmtEntry> seq, const StmtNode* parent) {
std::vector<AccessEntry> read_seq;
std::unordered_map<const VarNode*, std::vector<AccessEntry> > write_set;

@@ -276,7 +276,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
}
}
// loop carry
- if (loop != nullptr) {
+ if (parent->IsInstance<ForNode>()) {
for (const AccessEntry& acc : read_seq) {
fupdate(seq.size(), acc);
}
@@ -287,7 +287,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
return read_seq;
}

- std::vector<AccessEntry> PlanReadBarrier(std::vector<StmtEntry> seq, const ForNode* loop) {
+ std::vector<AccessEntry> PlanReadBarrier(std::vector<StmtEntry> seq, const StmtNode* parent) {
std::vector<AccessEntry> write_seq;
std::unordered_map<const VarNode*, std::vector<AccessEntry> > read_set;

@@ -315,7 +315,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
}
}
// loop carry
- if (loop != nullptr) {
+ if (parent->IsInstance<ForNode>()) {
for (const AccessEntry& acc : write_seq) {
fupdate(0, acc);
}
153 changes: 153 additions & 0 deletions src/tir/transforms/inject_async_copy.cc
@@ -0,0 +1,153 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief Replace copy from global to shared with async copy
* \file inject_async_copy.cc
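 *
 * For example (sketch), a 4-lane float32 copy
 *   shared[ramp(dst, 1, 4)] = global[ramp(src, 1, 4)]
 * is rewritten to
 *   ptx_cp_async(shared_data, dst, global_data, src, 16)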
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "storage_access.h"

namespace tvm {
namespace tir {

class AsyncCopyInjector : public StmtMutator {
public:
std::unordered_set<const VarNode*> pending_async_copy_;

Stmt VisitStmt_(const BufferStoreNode* store) {
if (store->buffer.scope() == "shared") {
if (auto* load = store->value.as<BufferLoadNode>()) {
if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1);
const int bytes = load->indices[0]->dtype.lanes() * load->buffer->dtype.bytes();
if (bytes == 4 || bytes == 8 || bytes == 16) {
          // Only vectorized (Ramp) indices are rewritten; scalar indices would
          // make the as<RampNode>() casts below return nullptr.
          auto* dst_ramp = store->indices[0].as<RampNode>();
          auto* src_ramp = load->indices[0].as<RampNode>();
          if (dst_ramp && src_ramp) {
            pending_async_copy_.insert(store->buffer->data.get());
            return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
                                 {store->buffer->data, dst_ramp->base, load->buffer->data,
                                  src_ramp->base, PrimExpr(bytes)}));
          }
}
}
}
}
return StmtMutator::VisitStmt_(store);
}

  Stmt VisitStmt_(const SeqStmtNode* op) { return StmtMutator::VisitStmt_(op); }
};

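/*!
 * \brief Plan wait insertion: mark statements that read a shared buffer with a
 * pending cp.async write, so a wait group can be placed before them.
 */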
class InsertWaitGroupPlanner : public StorageAccessVisitor {
public:
explicit InsertWaitGroupPlanner(const std::unordered_set<const VarNode*>& pending_async_copy)
: pending_async_copy_{pending_async_copy} {}

std::unordered_set<const Object*> insert_wait_before_;

protected:
bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
return scope == StorageScope::Create("shared") && pending_async_copy_.count(buf) != 0;
}

std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const StmtNode* parent) final {
std::vector<AccessEntry> flattened;
std::vector<AccessEntry> pending_writes;

for (const StmtEntry& s : seq) {
bool wait_before_stmt = false;
for (const AccessEntry& acc : s.access) {
        // kSync entries carry no buffer, so only check read/write accesses.
        ICHECK(acc.type == kSync || pending_async_copy_.count(acc.buffer.get()) != 0);
if (acc.type == kRead) {
if (FindConflict(pending_writes, acc)) {
wait_before_stmt = true;
break;
}
} else if (acc.type == kWrite) {
pending_writes.push_back(acc);
}
flattened.push_back(acc);
}
if (wait_before_stmt) {
ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
insert_wait_before_.insert(s.stmt);
}
}
return flattened;
}

private:
bool FindConflict(const std::vector<AccessEntry>& pending_writes, const AccessEntry& read) {
for (const AccessEntry& pending_write : pending_writes) {
if (pending_write.buffer == read.buffer) {
return true;
}
}
return false;
}

const std::unordered_set<const VarNode*> pending_async_copy_;
};

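/*!
 * \brief Insert ptx_commit_group followed by ptx_wait_group(0) immediately
 * before each statement marked by InsertWaitGroupPlanner.
 */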
class InsertWaitGroup : public StmtMutator {
public:
explicit InsertWaitGroup(std::unordered_set<const Object*> insert_wait_before)
: insert_wait_before_(insert_wait_before) {}

Stmt VisitStmt(const Stmt& stmt) final {
if (insert_wait_before_.count(stmt.get())) {
auto commit_group =
Evaluate(Call(DataType::Void(), tvm::tir::builtin::ptx_commit_group(), {}));
auto wait_group =
Evaluate(Call(DataType::Void(), tvm::tir::builtin::ptx_wait_group(), {PrimExpr(0)}));
auto ret = StmtMutator::VisitStmt(stmt);
return SeqStmt({commit_group, wait_group, ret});
} else {
return StmtMutator::VisitStmt(stmt);
}
}

std::unordered_set<const Object*> insert_wait_before_;
};

namespace transform {

Pass InjectAsyncCopy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
    AsyncCopyInjector copy_inject;
    n->body = copy_inject(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectAsyncCopy", {});
}

TVM_REGISTER_GLOBAL("tir.transform.InjectAsyncCopy").set_body_typed(InjectAsyncCopy);

} // namespace transform

} // namespace tir
} // namespace tvm
29 changes: 24 additions & 5 deletions src/tir/transforms/storage_access.cc
@@ -110,7 +110,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
s.stmt = op;
- s.access = Summarize(std::move(scope_.back()), nullptr);
+ s.access = Summarize(std::move(scope_.back()), op);
scope_.pop_back();
if (!s.access.empty()) {
for (AccessEntry& e : s.access) {
@@ -134,7 +134,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
scope_.push_back(std::vector<StmtEntry>());
StmtExprVisitor::VisitStmt_(op);
// no need to take the result as the thread barrier automatically syncs.
- Summarize(std::move(scope_.back()), nullptr);
+ Summarize(std::move(scope_.back()), op);
in_device_env_ = false;
scope_.pop_back();
} else {
@@ -185,12 +185,12 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitStmt(op->then_case);
StmtEntry s;
s.stmt = op;
- s.access = Summarize(std::move(scope_.back()), nullptr);
+ s.access = Summarize(std::move(scope_.back()), op);
scope_.pop_back();
if (op->else_case.defined()) {
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->else_case);
- auto v = Summarize(std::move(scope_.back()), nullptr);
+ auto v = Summarize(std::move(scope_.back()), op);
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
}
@@ -205,7 +205,7 @@ void StorageAccessVisitor::VisitStmt_(const WhileNode* op) {
this->VisitStmt(op->body);
StmtEntry s;
s.stmt = op;
- s.access = Summarize(std::move(scope_.back()), nullptr);
+ s.access = Summarize(std::move(scope_.back()), op);
scope_.pop_back();
scope_.back().emplace_back(std::move(s));
--condition_counter_;
@@ -253,6 +253,25 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
e.scope = StorageScope::Create(s);
curr_stmt_.access.emplace_back(std::move(e));
}
+ } else if (op->op.same_as(builtin::ptx_cp_async())) {
+   ICHECK_EQ(curr_stmt_.access.size(), 0U);
+   curr_stmt_.stmt = op;
+
+   Var buf = Downcast<Var>(op->args[0]);
+   StorageScope scope = GetScope(buf);
+   AccessEntry e;
+   e.threads = env_threads();
+   e.buffer = buf;
+   e.dtype = op->dtype;
+   e.touched.push_back(arith::IntSet::Vector(op->args[1]));
+   e.type = kWrite;
+   e.scope = scope;
+   curr_stmt_.access.emplace_back(std::move(e));
+
+   // push to the scope
+   scope_.back().push_back(curr_stmt_);
+   // clear access entry.
+   curr_stmt_.access.clear();
} else {
StmtExprVisitor::VisitExpr_(op);
}
2 changes: 1 addition & 1 deletion src/tir/transforms/storage_access.h
@@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor {
* \return The summarized sequence of accesses that
* the parent should take care of to synchronize.
*/
- virtual std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) = 0;
+ virtual std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const StmtNode* parent) = 0;
/*!
* \brief Get the scope of the buffer array.
* \return The scope of the final buffer array.
6 changes: 3 additions & 3 deletions src/tir/transforms/thread_storage_sync.cc
@@ -49,7 +49,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
return in_device_env() && scope == sync_scope_;
}
// Plan the sync
- std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
+ std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const StmtNode* parent) final {
// Unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
@@ -101,7 +101,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
syncs_inserted_.insert(s.stmt);
}
}
- if (loop != nullptr) {
+ if (parent->IsInstance<ForNode>()) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0) break;
@@ -166,7 +166,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
}
}
head.insert(head.end(), tail.begin(), tail.end());
- if (loop != nullptr) {
+ if (parent->IsInstance<ForNode>()) {
// clear double buffer flag after a loop is finished.
for (AccessEntry& e : head) {
e.double_buffer_write = false;
