From 73907d1e56f0cd1339029a8d9bfbbda6e1829cd8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 12 May 2023 11:42:32 -0500 Subject: [PATCH 1/3] [TIR] Handle DeclBuffer in LowerCustomDatatypes Preserve DeclBuffer node when transforming with `LowerCustomDatatypes` This is a subset of changes, being split out from https://github.com/apache/tvm/pull/14778 into independent portions. --- src/tir/transforms/lower_custom_datatypes.cc | 5 + .../python/unittest/test_custom_datatypes.py | 236 +++++++++++------- 2 files changed, 156 insertions(+), 85 deletions(-) diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index b2f95ad2d590..c5bcda2effc9 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -103,6 +103,11 @@ class CustomDatatypesLowerer : public StmtExprMutator { } } + Stmt VisitStmt_(const DeclBufferNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); auto modified = VisitBufferAccess(node); diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 5c7a429317a5..268ac0501b32 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -34,6 +34,7 @@ register_op, ) from tvm.tir.op import call_pure_extern +from tvm.script import tir as T # note: we can't use relay.testing models because params are randomly initialized, @@ -116,88 +117,105 @@ def setup_myfloat(): Own Datatypes framework. """ - # To use datatype operations in an external library, you should first load - # the library containing the datatype implementation: - # CDLL("libposit.so", RTLD_GLOBAL) - # In this case, the datatype library we are using is built right into TVM, - # so we do not need to explicitly load any library. + def _setup_myfloat_inner(): + # To use datatype operations in an external library, you should first load + # the library containing the datatype implementation: + # CDLL("libposit.so", RTLD_GLOBAL) + # In this case, the datatype library we are using is built right into TVM, + # so we do not need to explicitly load any library. - # You can pick a code for your datatype arbitrarily, as long as it is - # greater than 128 and has not already been chosen. - register("myfloat", 131) + # You can pick a code for your datatype arbitrarily, as long as it is + # greater than 128 and has not already been chosen. + register("myfloat", 131) - register_op( - create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat" - ) - register_op( - create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float" - ) - register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat") - register_op( - create_lower_func( - { - 32: "Custom32Sub", - } - ), - "Sub", - "llvm", - "myfloat", - ) - register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat") - register_op( - create_lower_func( - { - 32: "FloatToCustom32", - } - ), - "FloatImm", - "llvm", - "myfloat", - ) - register_op( - create_lower_func( - { - 32: "Custom32Div", - } - ), - "Div", - "llvm", - "myfloat", - ) - register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat") - register_op( - create_lower_func({32: "Custom32Sqrt"}), - "Call", - "llvm", - "myfloat", - intrinsic_name="tir.sqrt", - ) - register_op( - create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat", intrinsic_name="tir.exp" - ) - register_op( - create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat", intrinsic_name="tir.log" - ) - register_op( - create_lower_func({32: "Custom32Sigmoid"}), - "Call", - "llvm", - "myfloat", - intrinsic_name="tir.sigmoid", - ) - register_op( - create_lower_func({32: "Custom32Tanh"}), - "Call", - "llvm", - "myfloat", - intrinsic_name="tir.tanh", - ) - register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else") - register_op( - lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern" - ) + register_op( + create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat" + ) + register_op( + create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float" + ) + register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat") + register_op( + create_lower_func( + { + 32: "Custom32Sub", + } + ), + "Sub", + "llvm", + "myfloat", + ) + register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat") + register_op( + create_lower_func( + { + 32: "FloatToCustom32", + } + ), + "FloatImm", + "llvm", + "myfloat", + ) + register_op( + create_lower_func( + { + 32: "Custom32Div", + } + ), + "Div", + "llvm", + "myfloat", + ) + register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat") + register_op( + create_lower_func({32: "Custom32Sqrt"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.sqrt", + ) + register_op( + create_lower_func({32: "Custom32Exp"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.exp", + ) + register_op( + create_lower_func({32: "Custom32Log"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.log", + ) + register_op( + create_lower_func({32: "Custom32Sigmoid"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.sigmoid", + ) + register_op( + create_lower_func({32: "Custom32Tanh"}), + "Call", + "llvm", + "myfloat", + intrinsic_name="tir.tanh", + ) + register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else") + register_op( + lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern" + ) - register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat") + register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat") + + try: + _setup_myfloat_inner() + except tvm._ffi.base.TVMError as e: + # Ignore this specific error which can happen if another test + # that uses "myfloat" has already run. + if "float is already registered" not in str(e): + raise e def setup_posites2(): @@ -513,12 +531,8 @@ def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6): def test_myfloat(): - try: - setup_myfloat() - except tvm._ffi.base.TVMError as e: - if "float is already registered" not in str(e): - # Ignore this specific error which can happen if this test runs twice within the same process - raise e + setup_myfloat() + run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6) @@ -529,6 +543,58 @@ def test_myfloat(): # 'custom[myfloat]32') +class TestMyfloatLowering(tvm.testing.CompareBeforeAfter): + setup_myfloat() + + transform = tvm.tir.transform.LowerCustomDatatypes() + + def before(A_data: T.handle("custom[myfloat]32")): + T.func_attr({"target": T.target("llvm")}) + A = T.Buffer(16, dtype="custom[myfloat]32", data=A_data) + B_data = T.allocate([16], dtype="custom[myfloat]32") + B = T.Buffer(16, dtype="custom[myfloat]32", data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + def expected(A_data: T.handle("custom[myfloat]32")): + T.func_attr({"target": T.target("llvm")}) + A_uint32 = T.Buffer(16, "uint32", data=A_data) + B_data = T.allocate([16], dtype="uint32") + B_uint32 = T.Buffer(16, "uint32", data=B_data) + for i in range(16): + B_uint32[i] = T.call_pure_extern( + "uint32", + "FloatToCustom32", + T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1), + ) + + +class TestMyfloatLoweringDeclBuffer(tvm.testing.CompareBeforeAfter): + """Like TestMyfloatLoweringDeclBuffer, but using DeclBuffer""" + + setup_myfloat() + + transform = tvm.tir.transform.LowerCustomDatatypes() + + def before(A_data: T.handle("custom[myfloat]32")): + T.func_attr({"target": T.target("llvm")}) + A = T.decl_buffer(16, dtype="custom[myfloat]32", data=A_data) + B = T.decl_buffer(16, dtype="custom[myfloat]32") + for i in range(16): + B[i] = A[i] + 1.0 + + def expected(A_data: T.handle("custom[myfloat]32")): + T.func_attr({"target": T.target("llvm")}) + A_uint32 = T.decl_buffer(16, "uint32", data=A_data) + B_uint32 = T.decl_buffer(16, dtype="uint32") + for i in range(16): + B_uint32[i] = T.call_pure_extern( + "uint32", + "FloatToCustom32", + T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1), + ) + + def _has_posit(): return tvm.support.libinfo()["USE_BYODT_POSIT"] == "ON" From efce996b98aefea50b8dbce9c8d7d2618af31d72 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Jun 2023 14:47:33 -0500 Subject: [PATCH 2/3] Fix lint error --- .../python/unittest/test_custom_datatypes.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 268ac0501b32..31a199026f1d 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -547,16 +547,17 @@ class TestMyfloatLowering(tvm.testing.CompareBeforeAfter): setup_myfloat() transform = tvm.tir.transform.LowerCustomDatatypes() + dtype = "custom[myfloat]32" - def before(A_data: T.handle("custom[myfloat]32")): + def before(A_data: T.handle(dtype)): T.func_attr({"target": T.target("llvm")}) - A = T.Buffer(16, dtype="custom[myfloat]32", data=A_data) - B_data = T.allocate([16], dtype="custom[myfloat]32") - B = T.Buffer(16, dtype="custom[myfloat]32", data=B_data) + A = T.Buffer(16, dtype=dtype, data=A_data) + B_data = T.allocate([16], dtype=dtype) + B = T.Buffer(16, dtype=dtype, data=B_data) for i in range(16): B[i] = A[i] + 1.0 - def expected(A_data: T.handle("custom[myfloat]32")): + def expected(A_data: T.handle(dtype)): T.func_attr({"target": T.target("llvm")}) A_uint32 = T.Buffer(16, "uint32", data=A_data) B_data = T.allocate([16], dtype="uint32") @@ -576,14 +577,16 @@ class TestMyfloatLoweringDeclBuffer(tvm.testing.CompareBeforeAfter): transform = tvm.tir.transform.LowerCustomDatatypes() - def before(A_data: T.handle("custom[myfloat]32")): + dtype = "custom[myfloat]32" + + def before(A_data: T.handle(dtype)): T.func_attr({"target": T.target("llvm")}) - A = T.decl_buffer(16, dtype="custom[myfloat]32", data=A_data) - B = T.decl_buffer(16, dtype="custom[myfloat]32") + A = T.decl_buffer(16, dtype=dtype, data=A_data) + B = T.decl_buffer(16, dtype=dtype) for i in range(16): B[i] = A[i] + 1.0 - def expected(A_data: T.handle("custom[myfloat]32")): + def expected(A_data: T.handle(dtype)): T.func_attr({"target": T.target("llvm")}) A_uint32 = T.decl_buffer(16, "uint32", data=A_data) B_uint32 = T.decl_buffer(16, dtype="uint32") From ecf3497cfa92a79c56d28c066f5b94c032637276 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Jun 2023 07:11:03 -0500 Subject: [PATCH 3/3] Fix parsing error introduced by lint fix --- .../python/unittest/test_custom_datatypes.py | 101 +++++++++++------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 31a199026f1d..41ccec5ad21f 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -547,27 +547,38 @@ class TestMyfloatLowering(tvm.testing.CompareBeforeAfter): setup_myfloat() transform = tvm.tir.transform.LowerCustomDatatypes() - dtype = "custom[myfloat]32" - - def before(A_data: T.handle(dtype)): - T.func_attr({"target": T.target("llvm")}) - A = T.Buffer(16, dtype=dtype, data=A_data) - B_data = T.allocate([16], dtype=dtype) - B = T.Buffer(16, dtype=dtype, data=B_data) - for i in range(16): - B[i] = A[i] + 1.0 - - def expected(A_data: T.handle(dtype)): - T.func_attr({"target": T.target("llvm")}) - A_uint32 = T.Buffer(16, "uint32", data=A_data) - B_data = T.allocate([16], dtype="uint32") - B_uint32 = T.Buffer(16, "uint32", data=B_data) - for i in range(16): - B_uint32[i] = T.call_pure_extern( - "uint32", - "FloatToCustom32", - T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1), - ) + + def before(self): + dtype = "custom[myfloat]32" + + @T.prim_func + def func(A_data: T.handle(dtype)): + T.func_attr({"target": T.target("llvm")}) + A = T.Buffer(16, dtype=dtype, data=A_data) + B_data = T.allocate([16], dtype=dtype) + B = T.Buffer(16, dtype=dtype, data=B_data) + for i in range(16): + B[i] = A[i] + 1.0 + + return func + + def expected(self): + dtype = "custom[myfloat]32" + + @T.prim_func + def func(A_data: T.handle(dtype)): + T.func_attr({"target": T.target("llvm")}) + A_uint32 = T.Buffer(16, "uint32", data=A_data) + B_data = T.allocate([16], dtype="uint32") + B_uint32 = T.Buffer(16, "uint32", data=B_data) + for i in range(16): + B_uint32[i] = T.call_pure_extern( + "uint32", + "FloatToCustom32", + T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1), + ) + + return func class TestMyfloatLoweringDeclBuffer(tvm.testing.CompareBeforeAfter): @@ -577,25 +588,35 @@ class TestMyfloatLoweringDeclBuffer(tvm.testing.CompareBeforeAfter): transform = tvm.tir.transform.LowerCustomDatatypes() - dtype = "custom[myfloat]32" - - def before(A_data: T.handle(dtype)): - T.func_attr({"target": T.target("llvm")}) - A = T.decl_buffer(16, dtype=dtype, data=A_data) - B = T.decl_buffer(16, dtype=dtype) - for i in range(16): - B[i] = A[i] + 1.0 - - def expected(A_data: T.handle(dtype)): - T.func_attr({"target": T.target("llvm")}) - A_uint32 = T.decl_buffer(16, "uint32", data=A_data) - B_uint32 = T.decl_buffer(16, dtype="uint32") - for i in range(16): - B_uint32[i] = T.call_pure_extern( - "uint32", - "FloatToCustom32", - T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1), - ) + def before(self): + dtype = "custom[myfloat]32" + + @T.prim_func + def func(A_data: T.handle(dtype)): + T.func_attr({"target": T.target("llvm")}) + A = T.decl_buffer(16, dtype=dtype, data=A_data) + B = T.decl_buffer(16, dtype=dtype) + for i in range(16): + B[i] = A[i] + 1.0 + + return func + + def expected(self): + dtype = "custom[myfloat]32" + + @T.prim_func + def func(A_data: T.handle(dtype)): + T.func_attr({"target": T.target("llvm")}) + A_uint32 = T.decl_buffer(16, "uint32", data=A_data) + B_uint32 = T.decl_buffer(16, dtype="uint32") + for i in range(16): + B_uint32[i] = T.call_pure_extern( + "uint32", + "FloatToCustom32", + T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1), + ) + + return func def _has_posit():