Skip to content

Commit

Permalink
[TIR] Preserve AllocateNode::annotations (apache#15242)
Browse files Browse the repository at this point in the history
Prior to this commit, some lowering passes would erroneously strip out
the annotations from `Allocate` nodes.  This commit updates these
passes to preserve the annotations where present.
  • Loading branch information
Lunderberg authored and junrushao committed Jul 15, 2023
1 parent 761ca3b commit 88ba4ea
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/tir/transforms/inject_double_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ class DoubleBufferInjector : public StmtExprMutator {
Array<PrimExpr> new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)};
ICHECK(entry.loop != nullptr);
auto& alloc_nest = loop_allocs_[entry.loop];
alloc_nest.emplace_back(
Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
alloc_nest.emplace_back(Allocate(op->buffer_var, op->dtype, new_extents, op->condition,
Evaluate(0), op->annotations));
Stmt body = op->body;
if (auto ptr = body.as<DeclBufferNode>()) {
auto new_buf = GetRemappedBuffer(ptr->buffer, entry.stride);
Expand Down
3 changes: 2 additions & 1 deletion src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ class IRConvertSSA final : public StmtExprMutator {
ScopedRedefine redefine(this, v);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body);
return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body,
op->annotations);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_custom_datatypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
allocate = stmt.as<AllocateNode>();

return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
allocate->body);
allocate->body, allocate->annotations);
} else {
return StmtExprMutator::VisitStmt_(allocate);
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop
// use volatile access to shared buffer.
body = AttrStmt(remapped, attr::volatile_scope, 1, body);
}
return Allocate(remapped, op->dtype, op->extents, op->condition, body);
return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations);
}
return StmtExprMutator::VisitStmt_(op);
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
alloc_size = warp_group_ * factor;

return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)},
op->condition, this->VisitStmt(op->body));
op->condition, this->VisitStmt(op->body), op->annotations);
}

protected:
Expand Down
9 changes: 6 additions & 3 deletions src/tir/transforms/update_pointer_storage_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <utility>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
Expand Down Expand Up @@ -59,9 +60,11 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) {
}

Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) {
auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition),
StmtExprMutator::VisitStmt(op->body));
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
if (auto it = new_var_remap_.find(node->buffer_var.get()); it != new_var_remap_.end()) {
node.CopyOnWrite()->buffer_var = it->second;
}
return std::move(node);
}

template <typename Node>
Expand Down

0 comments on commit 88ba4ea

Please sign in to comment.