Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Flatten SeqStmt on construction #14492

Merged
merged 13 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 56 additions & 44 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,46 @@ class SeqStmtNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};

/*!
* \brief Evaluates an expression.
* This is mostly used for putting a Call node into Stmt.
*
* If value do not have side-effect, this node can be safely removed.
*/
class EvaluateNode : public StmtNode {
public:
/*! \brief The expression to be evaluated. */
PrimExpr value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }

static constexpr const char* _type_key = "tir.Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};

/*!
* \brief Managed reference to EvaluateNode.
* \sa EvaluateNode
*/
class Evaluate : public Stmt {
public:
TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());

explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}

TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
};

/*! \brief Sequence statement. */
class SeqStmt : public Stmt {
public:
Expand Down Expand Up @@ -726,8 +766,13 @@ class SeqStmt : public Stmt {
static Stmt Flatten(Args&&... seq_args) {
Array<Stmt> seq;
runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
if (seq.size() == 1) return seq[0];
return SeqStmt(seq);
if (seq.empty()) {
return Evaluate(0);
} else if (seq.size() == 1) {
return seq[0];
} else {
return SeqStmt(seq);
}
}
/*! \brief Helper class to flatten sequence of arguments into Array. */
class Flattener {
Expand All @@ -738,9 +783,16 @@ class SeqStmt : public Stmt {
if (!stmt.defined()) return;
if (auto* op = stmt.as<SeqStmtNode>()) {
operator()(0, op->seq);
} else {
seq_->push_back(stmt);
return;
}

if (auto* op = stmt.as<EvaluateNode>()) {
if (auto* as_int = op->value.as<IntImmNode>(); as_int && as_int->value == 0) {
return;
}
}

seq_->push_back(stmt);
}

template <typename T>
Expand Down Expand Up @@ -805,46 +857,6 @@ class IfThenElse : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode);
};

/*!
* \brief Evaluates an expression.
* This is mostly used for putting a Call node into Stmt.
*
* If value do not have side-effect, this node can be safely removed.
*/
class EvaluateNode : public StmtNode {
public:
/*! \brief The expression to be evaluated. */
PrimExpr value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }

static constexpr const char* _type_key = "tir.Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};

/*!
* \brief Managed reference to EvaluateNode.
* \sa EvaluateNode
*/
class Evaluate : public Stmt {
public:
TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());

explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}

TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
};

/*!
* \brief The kind of the loop.
*
Expand Down
6 changes: 3 additions & 3 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}));
}

tir::Stmt body = tir::SeqStmt({func_call});
tir::Stmt body = tir::SeqStmt::Flatten(func_call);
stmts_.push_back(body);
}

Expand Down Expand Up @@ -570,7 +570,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
{tvm::tir::StringImm(device_hook_name), context})));
device_hooks.push_back(device_hook);
}
return tir::SeqStmt(device_hooks);
return tir::SeqStmt::Flatten(device_hooks);
}

/**
Expand Down Expand Up @@ -736,7 +736,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// the packed function calls don't pack their arguments. The AOT
// runner function needs to be legalized by the LegalizePackedCalls pass.
tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) {
tir::Stmt body = tir::SeqStmt(stmts_);
tir::Stmt body = tir::SeqStmt::Flatten(stmts_);
// Allocate the sids
std::unordered_map<int, bool> allocated;

Expand Down
19 changes: 19 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,25 @@ TVM_REGISTER_NODE_TYPE(PrefetchNode);

// SeqStmt
SeqStmt::SeqStmt(Array<Stmt> seq, Span span) {
bool requires_flattening = std::any_of(
seq.begin(), seq.end(), [](const Stmt& stmt) { return stmt->IsInstance<SeqStmtNode>(); });

if (requires_flattening) {
auto flattened = SeqStmt::Flatten(seq);
if (auto* ptr = flattened.as<SeqStmtNode>()) {
seq = ptr->seq;
} else {
seq = {flattened};
}
}

ICHECK_NE(seq.size(), 0) << "An empty SeqStmt is prohibited. "
<< "To write a no-op, use Evaluate(0), "
<< "or the result of SeqStmt::Flatten()";
ICHECK_NE(seq.size(), 1) << "A SeqStmt of length 1 is prohibited. "
<< "Use the node " << seq[0] << "directly, "
<< "or for dynamic usage, normalize using SeqStmt::Flatten()";

auto node = make_object<SeqStmtNode>();
node->seq = std::move(seq);
node->span = std::move(span);
Expand Down
4 changes: 1 addition & 3 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,7 @@ Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
if (seq.same_as(op->seq)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->seq = std::move(seq);
return Stmt(n);
return SeqStmt::Flatten(seq);
}
}

Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/aot/test_c_device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_without_device_api_unpacked_api(non_device_api_main_func):
"""Test a graph without the Device API with the unpacked internal calls"""

main_func = non_device_api_main_func(interface_api="c", use_unpacked_api=True)
body = main_func.body.seq[1].seq[0].seq[0].value
body = main_func.body.value
assert (
repr(body)
== 'T.tvm_check_return(0, -1, T.call_extern("int32", '
Expand All @@ -252,7 +252,7 @@ def test_without_device_api_packed_api(non_device_api_main_func):

main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False)

body = main_func.body.seq[1].seq[0].seq[0].value
body = main_func.body.value
assert repr(body) == (
'T.call_cpacked("tvmgen_default_fused_multiply", '
"T.tvm_stack_make_array(x_buffer_var, T.tvm_stack_make_shape(10, 10), "
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3692,6 +3692,30 @@ def func(
return func


def nested_seqstmt():
"""Nested SeqStmt should be normalized to flat SeqStmt

Nested SeqStmt are representable in the TIR structures, but are
flattened when converted to TVMScript. Previously, this could
cause failures to round-trip through TVMScript, including
erroneous use of TVMScript's concise-scoping rules. This was
resolved by normalizing nested SeqStmt in TIR, such that the use
of `tir.SeqStmt` below results in a single flat `tir.SeqStmt`
containing the three `tir.Evaluate` calls.
"""
func = tvm.tir.PrimFunc(
params=[],
body=tvm.tir.SeqStmt(
[
tvm.tir.SeqStmt([tvm.tir.Evaluate(0), tvm.tir.Evaluate(1)]),
tvm.tir.Evaluate(2),
]
),
)

return func


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
Expand Down Expand Up @@ -3757,6 +3781,7 @@ def func(
merge_shape_var_def,
if_then_else_var,
tvm_shfl_builtins,
nested_seqstmt,
)


Expand Down