diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 88188425a9e8..4e2e79db26da 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -119,8 +119,8 @@ class DoubleBufferInjector : public StmtExprMutator { Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; ICHECK(entry.loop != nullptr); auto& alloc_nest = loop_allocs_[entry.loop]; - alloc_nest.emplace_back( - Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); + alloc_nest.emplace_back(Allocate(op->buffer_var, op->dtype, new_extents, op->condition, + Evaluate(0), op->annotations)); Stmt body = op->body; if (auto ptr = body.as()) { auto new_buf = GetRemappedBuffer(ptr->buffer, entry.stride); diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 43bf6b983eb5..99ed4376590e 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -335,7 +335,8 @@ class IRConvertSSA final : public StmtExprMutator { ScopedRedefine redefine(this, v); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body); + return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body, + op->annotations); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index c5bcda2effc9..273d37829dcb 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -97,7 +97,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { allocate = stmt.as(); return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition, - allocate->body); + allocate->body, allocate->annotations); } else { return StmtExprMutator::VisitStmt_(allocate); } diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index f6cda51f43ad..c1566936c531 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -53,7 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop // use volatile access to shared buffer. body = AttrStmt(remapped, attr::volatile_scope, 1, body); } - return Allocate(remapped, op->dtype, op->extents, op->condition, body); + return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations); } return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 571f512bfd14..870235954689 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -249,7 +249,7 @@ class WarpAccessRewriter : protected StmtExprMutator { alloc_size = warp_group_ * factor; return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)}, - op->condition, this->VisitStmt(op->body)); + op->condition, this->VisitStmt(op->body), op->annotations); } protected: diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 18950bc1997f..2049487b4a78 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -29,6 +29,7 @@ #include #include +#include #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" @@ -59,9 +60,11 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { } Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { - auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); - return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), - StmtExprMutator::VisitStmt(op->body)); + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + if (auto it = new_var_remap_.find(node->buffer_var.get()); it != new_var_remap_.end()) { + node.CopyOnWrite()->buffer_var = it->second; + } + return std::move(node); } template