diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 01f79bd0c7..103aee9911 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -235,6 +235,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc doc; std::ostringstream os; if (dtype.is_float() || dtype.is_float16() || dtype.is_bfloat16()) { + os.setf(std::ios::showpoint); os.precision(17); } os << data[0]; @@ -242,6 +243,8 @@ class TVMScriptPrinter : public StmtFunctor, doc << Doc::Text(os.str()); } else if (dtype == DataType::Bool()) { doc << Doc::Text(data[0] ? "True" : "False"); + } else if (dtype == DataType::Float(32)) { + doc << Doc::Text(os.str()); } else { doc << "tir." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) << ")"; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 0a9618cb8c..05bf07296e 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2960,6 +2960,20 @@ def test_abs(): tvm.ir.assert_structural_equal(func, rt_func) +@tvm.script.tir +def printer(a: ty.handle) -> None: + A = tir.match_buffer(a, (), "float32") + A[()] = tir.min(2.2, 5.2) + A[()] = tir.max(tir.float32(2.2), tir.float32(5.2)) + A[()] = tir.min(2.2, 5.0) + + +def test_script_printer(): + func = printer + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + if __name__ == "__main__": test_opt_gemm_normalize() test_opt_gemm_mod_host() @@ -2977,3 +2991,4 @@ def test_abs(): test_block_elements() test_opaque_block() test_abs() + test_script_printer()