diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 51c99fb375f3..7e40b329d290 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -690,6 +690,46 @@ class SeqStmtNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); }; +/*! + * \brief Evaluates an expression. + * This is mostly used for putting a Call node into Stmt. + * + * If the value does not have side effects, this node can be safely removed. + */ +class EvaluateNode : public StmtNode { + public: + /*! \brief The expression to be evaluated. */ + PrimExpr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("span", &span); + } + + bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "tir.Evaluate"; + TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); +}; + +/*! + * \brief Managed reference to EvaluateNode. + * \sa EvaluateNode + */ +class Evaluate : public Stmt { + public: + TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span()); + + explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {} + + TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode); +}; + /*! \brief Sequence statement. */ class SeqStmt : public Stmt { public: @@ -718,6 +758,10 @@ class SeqStmt : public Stmt { * \note This function can directly return an element * if it is the only element in the sequence. * + * \note If the only argument to this function is a SeqStmt, and if + * no flattening of the SeqStmt is required, then the SeqStmt + * will be returned as-is. + * * \param seq_args The list of arguments to be flattened. * \tparam Args arguments * \return The constructed statement @@ -726,7 +770,36 @@ class SeqStmt : public Stmt { static Stmt Flatten(Args&&... 
seq_args) { Array seq; runtime::detail::for_each(Flattener(&seq), std::forward(seq_args)...); - if (seq.size() == 1) return seq[0]; + + if (seq.empty()) { + return Evaluate(0); + } else if (seq.size() == 1) { + return seq[0]; + } + + // If the argument is a single SeqStmt argument with no + // flattening or unwrapping required, then we may + // return the SeqStmt as-is. + if constexpr (sizeof...(seq_args) == 1) { + if (auto opt = Flattener::AsSeqStmt(std::forward(seq_args)...)) { + SeqStmt original = opt.value(); + bool all_same = [&]() { + if (original->seq.size() != seq.size()) { + return false; + } + for (size_t i = 0; i < seq.size(); i++) { + if (!original->seq[i].same_as(seq[i])) { + return false; + } + } + return true; + }(); + if (all_same) { + return original; + } + } + } + return SeqStmt(seq); } /*! \brief Helper class to flatten sequence of arguments into Array. */ @@ -734,26 +807,56 @@ class SeqStmt : public Stmt { public: explicit Flattener(Array* seq) : seq_(seq) {} + template + static Optional AsSeqStmt(const T& t) { + if constexpr (std::is_same_v) { + return t; + } else if constexpr (!std::is_base_of_v) { + return NullOpt; + } else if (auto* ptr = t.template as()) { + return GetRef(ptr); + } else { + return NullOpt; + } + } + template void operator()(size_t i, const T& stmt_or_seq) const { if constexpr (std::is_base_of_v) { // Early bail-out, applicable to any ObjectRef - if (!stmt_or_seq.defined()) return; + if (!stmt_or_seq.defined()) { + return; + } } if constexpr (std::is_same_v) { - // No need for dynamic type-checking if the static type is a - // SeqStmt. + // Static type-checking for a SeqStmt that could be flattened. (*this)(0, stmt_or_seq->seq); - } else if constexpr (std::is_base_of_v) { + return; + } + + if constexpr (std::is_base_of_v) { // Dynamic type-checking for a SeqStmt that could be // flattened. 
if (auto* op = stmt_or_seq.template as()) { operator()(0, op->seq); - } else { - seq_->push_back(stmt_or_seq); + return; } - } else if constexpr (std::is_base_of_v) { + } + + if constexpr (std::is_base_of_v) { + // Evaluate(0) is used to represent a no-op, and may be + // generated by previous calls to SeqStmt::Flatten(). These + // should be removed to ensure that Flatten(a+b) is equivalent + // to Flatten(Flatten(a), Flatten(b)). + if (auto* op = stmt_or_seq.template as()) { + if (auto* as_int = op->value.template as(); as_int && as_int->value == 0) { + return; + } + } + } + + if constexpr (std::is_base_of_v) { // Any other Stmt type just gets appended. seq_->push_back(stmt_or_seq); } else { @@ -819,46 +922,6 @@ class IfThenElse : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode); }; -/*! - * \brief Evaluates an expression. - * This is mostly used for putting a Call node into Stmt. - * - * If value do not have side-effect, this node can be safely removed. - */ -class EvaluateNode : public StmtNode { - public: - /*! \brief The expression to be evaluated. */ - PrimExpr value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - v->Visit("span", &span); - } - - bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { - return equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - - static constexpr const char* _type_key = "tir.Evaluate"; - TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); -}; - -/*! - * \brief Managed reference to EvaluateNode. - * \sa EvaluateNode - */ -class Evaluate : public Stmt { - public: - TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span()); - - explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {} - - TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode); -}; - /*! * \brief The kind of the loop. 
* diff --git a/src/relay/backend/aot/aot_lower_main.cc b/src/relay/backend/aot/aot_lower_main.cc index 5688d22c1ba3..c7752c08053f 100644 --- a/src/relay/backend/aot/aot_lower_main.cc +++ b/src/relay/backend/aot/aot_lower_main.cc @@ -417,7 +417,7 @@ class AOTMainLowerer : public MixedModeVisitor { * runner function needs to be legalized by the LegalizePackedCalls pass. */ tir::PrimFunc CreateMainFunc(String mod_name) { - tir::Stmt body = tir::SeqStmt(stmts_); + tir::Stmt body = tir::SeqStmt::Flatten(stmts_); // Allocate the sids std::unordered_map allocated; std::vector> sids_to_allocate; @@ -674,7 +674,7 @@ class AOTMainLowerer : public MixedModeVisitor { })); } - tir::Stmt body = tir::SeqStmt({func_call}); + tir::Stmt body = tir::SeqStmt::Flatten(func_call); stmts_.push_back(body); } @@ -717,7 +717,7 @@ class AOTMainLowerer : public MixedModeVisitor { {tvm::tir::StringImm(device_hook_name), context}))); device_hooks.push_back(device_hook); } - return tir::SeqStmt(device_hooks); + return tir::SeqStmt::Flatten(device_hooks); } /*! diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 77765bacd6e0..1261d9971762 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -494,7 +494,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { })); } - tir::Stmt body = tir::SeqStmt({func_call}); + tir::Stmt body = tir::SeqStmt::Flatten(func_call); stmts_.push_back(body); } @@ -570,7 +570,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { {tvm::tir::StringImm(device_hook_name), context}))); device_hooks.push_back(device_hook); } - return tir::SeqStmt(device_hooks); + return tir::SeqStmt::Flatten(device_hooks); } /** @@ -736,7 +736,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // the packed function calls don't pack their arguments. The AOT // runner function needs to be legalized by the LegalizePackedCalls pass. 
tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) { - tir::Stmt body = tir::SeqStmt(stmts_); + tir::Stmt body = tir::SeqStmt::Flatten(stmts_); // Allocate the sids std::unordered_map allocated; diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index f3b547532cfd..9703a2adc323 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -51,14 +51,7 @@ inline void AddToParent(tvm::tir::Stmt stmt) { * \return The SeqStmt. */ inline tvm::tir::Stmt AsStmt(const Array& stmt) { - using namespace tvm::tir; - if (stmt.empty()) { - return tvm::tir::Evaluate(0); - } else if (stmt.size() == 1) { - return stmt[0]; - } else { - return SeqStmt(stmt); - } + return tvm::tir::SeqStmt::Flatten(stmt); } /*! diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 369c4adc8536..5968febe9af0 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -144,7 +144,7 @@ class HoistAllocatesMutator : public StmtExprMutator { for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) { Allocate current_alloc = *it; if (it != allocates_.rbegin()) { - new_main_func_body = SeqStmt({new_main_func_body}); + new_main_func_body = SeqStmt::Flatten(new_main_func_body); } new_main_func_body = Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents, diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index e4569898f7ed..b32b9b6c4584 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -387,6 +387,25 @@ TVM_REGISTER_NODE_TYPE(PrefetchNode); // SeqStmt SeqStmt::SeqStmt(Array seq, Span span) { + bool requires_flattening = std::any_of( + seq.begin(), seq.end(), [](const Stmt& stmt) { return stmt->IsInstance(); }); + + if (requires_flattening) { + auto flattened = SeqStmt::Flatten(seq); + if (auto* ptr = flattened.as()) { + seq = ptr->seq; + } else { + seq = {flattened}; + } + } + + ICHECK_NE(seq.size(), 0) << "An empty 
SeqStmt is prohibited. " + << "To write a no-op, use Evaluate(0), " + << "or the result of SeqStmt::Flatten()"; + ICHECK_NE(seq.size(), 1) << "A SeqStmt of length 1 is prohibited. " + << "Use the node " << seq[0] << "directly, " + << "or for dynamic usage, normalize using SeqStmt::Flatten()"; + auto node = make_object(); node->seq = std::move(seq); node->span = std::move(span); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 075bbcd3ace1..ae797c0d791b 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -437,11 +437,11 @@ Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) { Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { Array seq = Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { - return GetRef(op); + return SeqStmt::Flatten(GetRef(op)); } else { - auto n = CopyOnWrite(op); - n->seq = std::move(seq); - return Stmt(n); + auto node = CopyOnWrite(op); + node->seq = std::move(seq); + return SeqStmt::Flatten(SeqStmt(node)); } } diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc index ed59fe645026..d8d1e8fc2572 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/tir/schedule/analysis/reducer.cc @@ -394,16 +394,15 @@ void ExtractReductionUpdates(const Optional& self, Block block, if (p_seq == nullptr && p_buf_store == nullptr) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); } - SeqStmt seq = - p_seq != nullptr ? GetRef(p_seq) : SeqStmt({GetRef(p_buf_store)}); - if (static_cast(seq->seq.size()) != n_buffers) { + Array seq = p_seq != nullptr ? p_seq->seq : Array{GetRef(p_buf_store)}; + if (static_cast(seq.size()) != n_buffers) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6); } // Step 2. // - Create BufferStores according to the variables being stored. // - Construct the mapping from reduction buffers to the index. 
- for (const Stmt& stmt : seq->seq) { + for (const Stmt& stmt : seq) { const auto* buf_store = stmt.as(); if (buf_store == nullptr) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index d35cf8b8d602..4179b00a3684 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -178,31 +178,6 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { } } - Stmt VisitStmt_(const SeqStmtNode* op) final { - auto ret = Downcast(StmtMutator::VisitSeqStmt_(op, true)); - - bool need_compact = std::any_of(ret->seq.begin(), ret->seq.end(), - [](const auto& stmt) { return is_no_op(stmt); }); - - if (need_compact) { - Array filtered; - for (Stmt stmt : ret->seq) { - if (!is_no_op(stmt)) { - filtered.push_back(std::move(stmt)); - } - } - ret = SeqStmt(filtered); - } - - if (ret->size() == 0) { - return Evaluate(0); - } else if (ret->size() == 1) { - return ret->seq[0]; - } else { - return std::move(ret); - } - } - Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = GetRef(op); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 2909915c3288..30b1bc78247a 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -280,7 +280,6 @@ TEST(IRF, StmtMutator) { auto* ref2 = body2.get(); auto* extentptr = body.as()->extents.get(); // construct a recursive SeqStmt. - body = SeqStmt({body}); body = SeqStmt({body, body2}); body = SeqStmt({body, body2}); body = v(std::move(body)); @@ -296,7 +295,7 @@ TEST(IRF, StmtMutator) { Stmt body2 = Evaluate(1); auto* extentptr = body.as()->extents.get(); // construct a recursive SeqStmt. 
- body = SeqStmt({body}); + body = SeqStmt({body, body2}); auto bref = body; body = SeqStmt({body, body2}); body = v(std::move(body)); diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py index 247b22eac494..e317d059f053 100644 --- a/tests/python/relay/aot/test_c_device_api.py +++ b/tests/python/relay/aot/test_c_device_api.py @@ -238,7 +238,7 @@ def test_without_device_api_unpacked_api(non_device_api_main_func): """Test a graph without the Device API with the unpacked internal calls""" main_func = non_device_api_main_func(interface_api="c", use_unpacked_api=True) - body = main_func.body.seq[1].seq[0].seq[0].value + body = main_func.body.value assert ( repr(body) == 'T.tvm_check_return(0, -1, T.call_extern("int32", ' @@ -252,7 +252,7 @@ def test_without_device_api_packed_api(non_device_api_main_func): main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False) - body = main_func.body.seq[1].seq[0].seq[0].value + body = main_func.body.value assert repr(body) == ( 'T.call_cpacked("tvmgen_default_fused_multiply", ' "T.tvm_stack_make_array(x_buffer_var, T.tvm_stack_make_shape(10, 10), " diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 1eb34b07d7ab..f7e5af18d20e 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -1042,7 +1042,7 @@ def test_aot_codegen_checks_returns(): main_func = main_ir_module["__tvm_main__"] # Check operator call is wrapped properly - body = main_func.body[1].seq[0].seq[0].value + body = main_func.body.value assert ( repr(body) == 'T.tvm_check_return(0, -1, T.call_extern("int32", "tvmgen_default_fused_add",' diff --git a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py index c79ba68481f6..a22e5bf466ee 100644 --- a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py +++ 
b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py @@ -31,7 +31,7 @@ def test_ir_transform(): def preorder(op): if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestC": - return tvm.tir.const(0, "int32") + return tvm.tir.const(42, "int32") return None def postorder(op): @@ -43,7 +43,7 @@ def postorder(op): body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"]) stmt_list = tvm.tir.stmt_list(body.body.body) assert stmt_list[0].value.args[1].args[0].value == "TestB" - assert stmt_list[1].value.value == 0 + assert stmt_list[1].value.value == 42 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py index 15c5a577f9f5..133ef01ed001 100644 --- a/tests/python/unittest/test_tir_transform_remove_no_op.py +++ b/tests/python/unittest/test_tir_transform_remove_no_op.py @@ -23,7 +23,7 @@ def nop(): - return tvm.tir.Evaluate(0) + return tvm.tir.Evaluate(1) def test_remove_no_op(): diff --git a/tests/python/unittest/test_tvmscript_printer_annotation.py b/tests/python/unittest/test_tvmscript_printer_annotation.py index 70d5b655fb37..72d2238b2b63 100644 --- a/tests/python/unittest/test_tvmscript_printer_annotation.py +++ b/tests/python/unittest/test_tvmscript_printer_annotation.py @@ -24,7 +24,7 @@ @T.prim_func def _func(): - T.evaluate(0) + T.evaluate(-1) T.evaluate(1) T.evaluate(2) T.evaluate(3) @@ -49,7 +49,7 @@ def test_annotation_multi_object_paths(): @T.prim_func def main(): - T.evaluate(0) + T.evaluate(-1) T.evaluate(1) # annotation 1 T.evaluate(2) T.evaluate(3) # annotation 3 @@ -75,7 +75,7 @@ def test_annotate_from_multi_obj(): @T.prim_func def main(): - T.evaluate(0) + T.evaluate(-1) T.evaluate(1) # annotation 1 T.evaluate(2) T.evaluate(3) # annotation 3 diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 25272d912da2..8427754db71e 100644 --- 
a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -396,14 +396,14 @@ def test_prefetch(): def test_seq_stmt(): with IRBuilder() as ib: with T.serial(10): - T.evaluate(0) T.evaluate(1) + T.evaluate(2) obj = ib.get().body _assert_print( obj, """ -T.evaluate(0) T.evaluate(1) +T.evaluate(2) """, ) diff --git a/tests/python/unittest/test_tvmscript_printer_underlining.py b/tests/python/unittest/test_tvmscript_printer_underlining.py index 4a4d17d0d89b..569f03d0f828 100644 --- a/tests/python/unittest/test_tvmscript_printer_underlining.py +++ b/tests/python/unittest/test_tvmscript_printer_underlining.py @@ -433,7 +433,7 @@ def main(a: T.int32, b: T.int32): def test_underline_from_multi_obj(): @T.prim_func def func(): - T.evaluate(0) + T.evaluate(-1) T.evaluate(1) T.evaluate(2) T.evaluate(3) @@ -456,7 +456,7 @@ def func(): @T.prim_func def main(): - T.evaluate(0) + T.evaluate(-1) T.evaluate(1) ^^^^^^^^^^^^^ T.evaluate(2) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index beba76152a39..757f74ab8396 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3722,19 +3722,6 @@ def func(A: T.Buffer(64, "float32")): return tvm.tir.transform.MakePackedAPI()(mod) -def ir_module_with_attrs(): - @I.ir_module - class Module: - I.module_attrs({"attr": 10}) - - @T.prim_func - def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): - for i in range(16): - B[i] = A[i] - - return Module - - def tvm_struct_set_generated_in_cpp(): """Ensure same dtype for tvm_struct_set in Python/C++ @@ -3768,6 +3755,43 @@ def tir_packed_call(A: T.Buffer(16)): return tvm.tir.transform.LowerTVMBuiltin()(Module) +def ir_module_with_attrs(): + @I.ir_module + class Module: + I.module_attrs({"attr": 10}) + + @T.prim_func + def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): + for i in 
range(16): + B[i] = A[i] + + return Module + + +def nested_seqstmt(): + """Nested SeqStmt should be normalized to flat SeqStmt + + Nested SeqStmt are representable in the TIR structures, but are + flattened when converted to TVMScript. Previously, this could + cause failures to round-trip through TVMScript, including + erroneous use of TVMScript's concise-scoping rules. This was + resolved by normalizing nested SeqStmt in TIR, such that the use + of `tir.SeqStmt` below results in a single flat `tir.SeqStmt` + containing the three `tir.Evaluate` calls. + """ + func = tvm.tir.PrimFunc( + params=[], + body=tvm.tir.SeqStmt( + [ + tvm.tir.SeqStmt([tvm.tir.Evaluate(0), tvm.tir.Evaluate(1)]), + tvm.tir.Evaluate(2), + ] + ), + ) + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3834,8 +3858,9 @@ def tir_packed_call(A: T.Buffer(16)): if_then_else_var, tvm_shfl_builtins, make_packed_api_result, - ir_module_with_attrs, tvm_struct_set_generated_in_cpp, + ir_module_with_attrs, + nested_seqstmt, )