diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index b2f95ad2d590..c5bcda2effc9 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -103,6 +103,11 @@ class CustomDatatypesLowerer : public StmtExprMutator { } } + Stmt VisitStmt_(const DeclBufferNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); auto modified = VisitBufferAccess(node); diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 5c7a429317a5..41ccec5ad21f 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -34,6 +34,7 @@ register_op, ) from tvm.tir.op import call_pure_extern +from tvm.script import tir as T # note: we can't use relay.testing models because params are randomly initialized, @@ -116,88 +117,105 @@ def setup_myfloat(): Own Datatypes framework. """ - # To use datatype operations in an external library, you should first load - # the library containing the datatype implementation: - # CDLL("libposit.so", RTLD_GLOBAL) - # In this case, the datatype library we are using is built right into TVM, - # so we do not need to explicitly load any library. + def _setup_myfloat_inner(): + # To use datatype operations in an external library, you should first load + # the library containing the datatype implementation: + # CDLL("libposit.so", RTLD_GLOBAL) + # In this case, the datatype library we are using is built right into TVM, + # so we do not need to explicitly load any library. - # You can pick a code for your datatype arbitrarily, as long as it is - # greater than 128 and has not already been chosen. - register("myfloat", 131) + # You can pick a code for your datatype arbitrarily, as long as it is + # greater than 128 and has not already been chosen. + register("myfloat", 131) - register_op( - create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat" - ) - register_op( - create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float" - ) - register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat") - register_op( - create_lower_func( - { - 32: "Custom32Sub", - } - ), - "Sub", - "llvm", - "myfloat", - ) - register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat") - register_op( - create_lower_func( - { - 32: "FloatToCustom32", - } - ), - "FloatImm", - "llvm", - "myfloat", - ) - register_op( - create_lower_func( - { - 32: "Custom32Div", - } - ), - "Div", - "llvm", - "myfloat", - ) - register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat") - register_op( - create_lower_func({32: "Custom32Sqrt"}), - "Call", - "llvm", - "myfloat", - intrinsic_name="tir.sqrt", - ) - register_op( - create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat", intrinsic_name="tir.exp" - ) - register_op( - create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat", intrinsic_name="tir.log" - ) - register_op( - create_lower_func({32: "Custom32Sigmoid"}), - "Call", - "llvm", - "myfloat", - intrinsic_name="tir.sigmoid", - ) - register_op( - create_lower_func({32: "Custom32Tanh"}), - "Call", - "llvm", - "myfloat", - intrinsic_name="tir.tanh", - ) - register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else") - register_op( - lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern" - ) + register_op( + create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat" + ) + register_op( + create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float" + ) + register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat") + register_op( + create_lower_func( + { + 32: "Custom32Sub", + } + ), + "Sub", + "llvm", + "myfloat", + ) + register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat") + register_op( + create_lower_func( + { + 32: "FloatToCustom32", + } + ), + "FloatImm", + "llvm", + "myfloat", + ) + register_op( + create_lower_func( + { + 32: "Custom32Div", + } + ), + "Div", + "llvm", + "myfloat", + ) + register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat") + register_op( + create_lower_func({32: "Custom32Sqrt"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.sqrt", + ) + register_op( + create_lower_func({32: "Custom32Exp"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.exp", + ) + register_op( + create_lower_func({32: "Custom32Log"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.log", + ) + register_op( + create_lower_func({32: "Custom32Sigmoid"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.sigmoid", + ) + register_op( + create_lower_func({32: "Custom32Tanh"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.tanh", + ) + register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else") + register_op( + lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern" + ) + + register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat") - register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat") + try: + _setup_myfloat_inner() + except tvm._ffi.base.TVMError as e: + # Ignore this specific error which can happen if another test + # that uses "myfloat" has already run. + if "float is already registered" not in str(e): + raise e def setup_posites2(): @@ -513,12 +531,8 @@ def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6): def test_myfloat(): - try: - setup_myfloat() - except tvm._ffi.base.TVMError as e: - if "float is already registered" not in str(e): - # Ignore this specific error which can happen if this test runs twice within the same process - raise e + setup_myfloat() + run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) @@ -529,6 +543,82 @@ def test_myfloat(): # 'custom[myfloat]32') +class TestMyfloatLowering(tvm.testing.CompareBeforeAfter): + setup_myfloat() + + transform = tvm.tir.transform.LowerCustomDatatypes() + + def before(self): + dtype = "custom[myfloat]32" + + @T.prim_func + def func(A_data: T.handle(dtype)): + T.func_attr({"target": T.target("llvm")}) + A = T.Buffer(16, dtype=dtype, data=A_data) + B_data = T.allocate([16], dtype=dtype) + B = T.Buffer(16, dtype=dtype, data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + return func + + def expected(self): + dtype = "custom[myfloat]32" + + @T.prim_func + def func(A_data: T.handle(dtype)): + T.func_attr({"target": T.target("llvm")}) + A_uint32 = T.Buffer(16, "uint32", data=A_data) + B_data = T.allocate([16], dtype="uint32") + B_uint32 = T.Buffer(16, "uint32", data=B_data) + for i in range(16): + B_uint32[i] = T.call_pure_extern( + "uint32", + "FloatToCustom32", + T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1), + ) + + return func + + +class TestMyfloatLoweringDeclBuffer(tvm.testing.CompareBeforeAfter): + """Like TestMyfloatLoweringDeclBuffer, but using DeclBuffer""" + + setup_myfloat() + + transform = tvm.tir.transform.LowerCustomDatatypes() + + def before(self): + dtype = "custom[myfloat]32" + + @T.prim_func + def func(A_data: T.handle(dtype)): + T.func_attr({"target": T.target("llvm")}) + A = T.decl_buffer(16, dtype=dtype, data=A_data) + B = T.decl_buffer(16, dtype=dtype) + for i in range(16): + B[i] = A[i] + 1.0 + + return func + + def expected(self): + dtype = "custom[myfloat]32" + + @T.prim_func + def func(A_data: T.handle(dtype)): + T.func_attr({"target": T.target("llvm")}) + A_uint32 = T.decl_buffer(16, "uint32", data=A_data) + B_uint32 = T.decl_buffer(16, dtype="uint32") + for i in range(16): + B_uint32[i] = T.call_pure_extern( + "uint32", + "FloatToCustom32", + T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1), + ) + + return func + + def _has_posit(): return tvm.support.libinfo()["USE_BYODT_POSIT"] == "ON"