diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index e3148a26c29a..5b7a600c7883 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -74,6 +74,7 @@ def test_something(): import pickle import platform import sys +import textwrap import time import shutil @@ -1712,3 +1713,208 @@ def fetch_model_from_url( def main(): test_file = inspect.getsourcefile(sys._getframe(1)) sys.exit(pytest.main([test_file] + sys.argv[1:])) + + +class CompareBeforeAfter: + """Utility for comparing before/after of TIR transforms + + A standard framework for writing tests that take a TIR PrimFunc as + input, apply a transformation, then either compare against an + expected output or assert that the transformation raised an error. + A test should subclass CompareBeforeAfter, defining class members + `before`, `transform`, and `expected`. CompareBeforeAfter will + then use these members to define a test method and test fixture. + + `transform` may be one of the following. + + - An instance of `tvm.ir.transform.Pass` + + - A method that takes no arguments and returns a `tvm.ir.transform.Pass` + + - A pytest fixture that returns a `tvm.ir.transform.Pass` + + `before` may be any one of the following. + + - An instance of `tvm.tir.PrimFunc`. This is allowed, but is not + the preferred method, as any errors in constructing the + `PrimFunc` occur while collecting the test, preventing any other + tests in the same file from being run. + + - An TVMScript function, without the ``@T.prim_func`` decoration. + The ``@T.prim_func`` decoration will be applied when running the + test, rather than at module import. + + - A method that takes no arguments and returns a `tvm.tir.PrimFunc` + + - A pytest fixture that returns a `tvm.tir.PrimFunc` + + `expected` may be any one of the following. The type of + `expected` defines the test being performed. If `expected` + provides a `tvm.tir.PrimFunc`, the result of the transformation + must match `expected`. If `expected` is an exception, then the + transformation must raise that exception type. + + - Any option supported for `before`. + + - The `Exception` class object, or a class object that inherits + from `Exception`. + + - A method that takes no arguments and returns `Exception` or a + class object that inherits from `Exception`. + + - A pytest fixture that returns `Exception` or an class object + that inherits from `Exception`. + + Examples + -------- + + .. python:: + + class TestRemoveIf(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.Simplify() + + def before(A: T.Buffer[1, "int32"]): + if True: + A[0] = 42 + else: + A[0] = 5 + + def expected(A: T.Buffer[1, "int32"]): + A[0] = 42 + + """ + + def __init_subclass__(cls): + if hasattr(cls, "before"): + cls.before = cls._normalize_before(cls.before) + if hasattr(cls, "expected"): + cls.expected = cls._normalize_expected(cls.expected) + if hasattr(cls, "transform"): + cls.transform = cls._normalize_transform(cls.transform) + + @classmethod + def _normalize_before(cls, func): + if hasattr(func, "_pytestfixturefunction"): + return func + + if isinstance(func, tvm.tir.PrimFunc): + + def inner(self): + # pylint: disable=unused-argument + return func + + elif cls._is_method(func): + + def inner(self): + # pylint: disable=unused-argument + return func(self) + + else: + + def inner(self): + # pylint: disable=unused-argument + source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func)) + return tvm.script.from_source(source_code) + + return pytest.fixture(inner) + + @classmethod + def _normalize_expected(cls, func): + if hasattr(func, "_pytestfixturefunction"): + return func + + if isinstance(func, tvm.tir.PrimFunc) or ( + inspect.isclass(func) and issubclass(func, Exception) + ): + + def inner(self): + # pylint: disable=unused-argument + return func + + elif cls._is_method(func): + + def inner(self): + # pylint: disable=unused-argument + return func(self) + + else: + + def inner(self): + # pylint: disable=unused-argument + source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func)) + return tvm.script.from_source(source_code) + + return pytest.fixture(inner) + + @classmethod + def _normalize_transform(cls, transform): + if hasattr(transform, "_pytestfixturefunction"): + return transform + + if isinstance(transform, tvm.ir.transform.Pass): + + def inner(self): + # pylint: disable=unused-argument + return transform + + elif cls._is_method(transform): + + def inner(self): + # pylint: disable=unused-argument + return transform(self) + + else: + + raise TypeError( + "Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass" + ) + + return pytest.fixture(inner) + + @staticmethod + def _is_method(func): + sig = inspect.signature(func) + return "self" in sig.parameters + + def test_compare(self, before, expected, transform): + """Unit test to compare the expected TIR PrimFunc to actual""" + + before_mod = tvm.IRModule.from_expr(before) + + if inspect.isclass(expected) and issubclass(expected, Exception): + with pytest.raises(expected): + after_mod = transform(before_mod) + + # This portion through pytest.fail isn't strictly + # necessary, but gives a better error message that + # includes the before/after. + after = after_mod["main"] + script = tvm.IRModule({"after": after, "before": before}).script() + pytest.fail( + msg=( + f"Expected {expected.__name__} to be raised from transformation, " + f"instead received TIR\n:{script}" + ) + ) + + elif isinstance(expected, tvm.tir.PrimFunc): + after_mod = transform(before_mod) + after = after_mod["main"] + + try: + tvm.ir.assert_structural_equal(after, expected) + except ValueError as err: + script = tvm.IRModule( + {"expected": expected, "after": after, "before": before} + ).script() + raise ValueError( + f"TIR after transformation did not match expected:\n{script}" + ) from err + + else: + raise TypeError( + f"tvm.testing.CompareBeforeAfter requires the `expected` fixture " + f"to return either `Exception`, an `Exception` subclass, " + f"or an instance of `tvm.tir.PrimFunc`. " + f"Instead, received {type(exception)}." + ) diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 529b45481177..4ac502b21191 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -136,31 +136,16 @@ def sls(n, d): assert "if" not in str(stmt) -class BaseBeforeAfter: - def test_simplify(self): - before = self.before - before_mod = tvm.IRModule.from_expr(before) - after_mod = tvm.tir.transform.Simplify()(before_mod) - after = after_mod["main"] - expected = self.expected - - try: - tvm.ir.assert_structural_equal(after, expected) - except ValueError as err: - script = tvm.IRModule({"expected": expected, "after": after, "before": before}).script() - raise ValueError( - f"Function after simplification did not match expected:\n{script}" - ) from err +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.Simplify() class TestLoadStoreNoop(BaseBeforeAfter): """Store of a value that was just read from the same location is a no-op.""" - @T.prim_func def before(A: T.Buffer[(1,), "float32"]): A[0] = A[0] - @T.prim_func def expected(A: T.Buffer[(1,), "float32"]): T.evaluate(0) @@ -174,11 +159,9 @@ class TestLoadStoreNoopAfterSimplify(BaseBeforeAfter): regression. """ - @T.prim_func def before(A: T.Buffer[(1,), "float32"]): A[0] = A[0] + (5.0 - 5.0) - @T.prim_func def expected(A: T.Buffer[(1,), "float32"]): T.evaluate(0) @@ -191,14 +174,12 @@ class TestNestedCondition(BaseBeforeAfter): constraint. """ - @T.prim_func def before(A: T.Buffer[(16,), "float32"]): for i in T.serial(16): if i == 5: if i == 5: A[i] = 0.0 - @T.prim_func def expected(A: T.Buffer[(16,), "float32"]): for i in T.serial(16): if i == 5: @@ -212,14 +193,12 @@ class TestNestedProvableCondition(BaseBeforeAfter): conditional. """ - @T.prim_func def before(A: T.Buffer[(16,), "float32"]): for i in T.serial(16): if i == 5: if i < 7: A[i] = 0.0 - @T.prim_func def expected(A: T.Buffer[(16,), "float32"]): for i in T.serial(16): if i == 5: @@ -233,14 +212,12 @@ class TestNestedVarCondition(BaseBeforeAfter): constraint. """ - @T.prim_func def before(A: T.Buffer[(16,), "float32"], n: T.int32): for i in T.serial(16): if i == n: if i == n: A[i] = 0.0 - @T.prim_func def expected(A: T.Buffer[(16,), "float32"], n: T.int32): for i in T.serial(16): if i == n: @@ -256,7 +233,6 @@ class TestAlteredBufferContents(BaseBeforeAfter): may not. """ - @T.prim_func def before(A: T.Buffer[(1,), "int32"], n: T.int32): if A[0] == n: A[0] = A[0] + 1 @@ -273,7 +249,6 @@ class TestNegationOfCondition(BaseBeforeAfter): condition is known to be false. """ - @T.prim_func def before(A: T.Buffer[(16,), "int32"]): for i in T.serial(16): if i == 5: @@ -282,7 +257,6 @@ def before(A: T.Buffer[(16,), "int32"]): else: A[i] = 1 - @T.prim_func def expected(A: T.Buffer[(16,), "int32"]): for i in T.serial(16): if i == 5: @@ -298,7 +272,6 @@ class TestNegationOfNotEqual(BaseBeforeAfter): ``i==5`` as the negation of a literal constraint. """ - @T.prim_func def before(A: T.Buffer[(16,), "int32"]): for i in T.serial(16): if i != 5: @@ -307,7 +280,6 @@ def before(A: T.Buffer[(16,), "int32"]): else: A[i] = 1 - @T.prim_func def expected(A: T.Buffer[(16,), "int32"]): for i in T.serial(16): if i != 5: @@ -321,7 +293,6 @@ class TestNegationOfVarCondition(BaseBeforeAfter): must rely on RewriteSimplifier recognizing the repeated literal. """ - @T.prim_func def before(A: T.Buffer[(16,), "int32"], n: T.int32): for i in T.serial(16): if i == n: @@ -330,7 +301,6 @@ def before(A: T.Buffer[(16,), "int32"], n: T.int32): else: A[i] = 1 - @T.prim_func def expected(A: T.Buffer[(16,), "int32"], n: T.int32): for i in T.serial(16): if i == n: @@ -346,14 +316,12 @@ class TestLiteralConstraintSplitBooleanAnd(BaseBeforeAfter): the condition is to ensure we exercise RewriteSimplifier. """ - @T.prim_func def before(A: T.Buffer[(16, 16), "int32"], n: T.int32): for i, j in T.grid(16, 16): if i == n and j == n: if i == n: A[i, j] = 0 - @T.prim_func def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32): for i, j in T.grid(16, 16): if i == n and j == n: @@ -371,7 +339,6 @@ class TestLiteralConstraintSplitBooleanOr(BaseBeforeAfter): RewriteSimplifier. """ - @T.prim_func def before(A: T.Buffer[(16, 16), "int32"], n: T.int32): for i, j in T.grid(16, 16): if i == n or j == n: @@ -382,7 +349,6 @@ def before(A: T.Buffer[(16, 16), "int32"], n: T.int32): else: A[i, j] = 2 - @T.prim_func def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32): for i, j in T.grid(16, 16): if i == n or j == n: diff --git a/tests/python/unittest/test_tvm_testing_before_after.py b/tests/python/unittest/test_tvm_testing_before_after.py new file mode 100644 index 000000000000..613d66ccdb2b --- /dev/null +++ b/tests/python/unittest/test_tvm_testing_before_after.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +import tvm.testing +from tvm.script import tir as T + + +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + def transform(self): + return lambda x: x + + +class TestBeforeAfterPrimFunc(BaseBeforeAfter): + @T.prim_func + def before(): + T.evaluate(0) + + expected = before + + +class TestBeforeAfterMethod(BaseBeforeAfter): + def before(self): + @T.prim_func + def func(): + T.evaluate(0) + + return func + + expected = before + + +class TestBeforeAfterFixture(BaseBeforeAfter): + @tvm.testing.fixture + def before(self): + @T.prim_func + def func(): + T.evaluate(0) + + return func + + expected = before + + +class TestBeforeAfterDelayedPrimFunc(BaseBeforeAfter): + def before(): + T.evaluate(0) + + expected = before + + +class TestBeforeAfterParametrizedFixture(BaseBeforeAfter): + n = tvm.testing.parameter(1, 8, 16) + + @tvm.testing.fixture + def before(self, n): + @T.prim_func + def func(A: T.Buffer[n, "float32"]): + for i in T.serial(n): + A[i] = 0.0 + + return func + + expected = before + + +if __name__ == "__main__": + tvm.testing.main()