Skip to content

Commit

Permalink
[TVMSCRIPT]Fix script printters StructuralEqual check failed (#8499)
Browse files Browse the repository at this point in the history

Co-authored-by: honghua.cao <honghua.cao@streamcomputing.com>
  • Loading branch information
Beya2019 and honghua.cao authored Jul 22, 2021
1 parent 59e96e0 commit 07243a8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,16 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc doc;
std::ostringstream os;
if (dtype.is_float() || dtype.is_float16() || dtype.is_bfloat16()) {
os.setf(std::ios::showpoint);
os.precision(17);
}
os << data[0];
if (dtype == DataType::Int(32)) {
doc << Doc::Text(os.str());
} else if (dtype == DataType::Bool()) {
doc << Doc::Text(data[0] ? "True" : "False");
} else if (dtype == DataType::Float(32)) {
doc << Doc::Text(os.str());
} else {
doc << "tir." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) << ")";
}
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2960,6 +2960,20 @@ def test_abs():
tvm.ir.assert_structural_equal(func, rt_func)


@tvm.script.tir
def printer(a: ty.handle) -> None:
A = tir.match_buffer(a, (), "float32")
A[()] = tir.min(2.2, 5.2)
A[()] = tir.max(tir.float32(2.2), tir.float32(5.2))
A[()] = tir.min(2.2, 5.0)


def test_script_printer():
func = printer
rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
tvm.ir.assert_structural_equal(func, rt_func)


if __name__ == "__main__":
test_opt_gemm_normalize()
test_opt_gemm_mod_host()
Expand All @@ -2977,3 +2991,4 @@ def test_abs():
test_block_elements()
test_opaque_block()
test_abs()
test_script_printer()

0 comments on commit 07243a8

Please sign in to comment.