Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[UnitTest][TIR] Testing utility for before/after transform tests (apa…
Browse files Browse the repository at this point in the history
…che#12264)

This PR adds `tvm.testing.CompareBeforeAfter`, a generalization of the `BaseBeforeAfter` utility previously used in `test_tir_transform_simplify.py`, which performs unit tests that perform a transformation on a TIR function and compare the results to an expected TIR output.  This arose when minimizing the boilerplate required for unit tests in the implementation of apache#12261.
  • Loading branch information
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent 2414801 commit 2b438ac
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 36 deletions.
206 changes: 206 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_something():
import pickle
import platform
import sys
import textwrap
import time
import shutil

Expand Down Expand Up @@ -1712,3 +1713,208 @@ def fetch_model_from_url(
def main():
test_file = inspect.getsourcefile(sys._getframe(1))
sys.exit(pytest.main([test_file] + sys.argv[1:]))


class CompareBeforeAfter:
"""Utility for comparing before/after of TIR transforms
A standard framework for writing tests that take a TIR PrimFunc as
input, apply a transformation, then either compare against an
expected output or assert that the transformation raised an error.
A test should subclass CompareBeforeAfter, defining class members
`before`, `transform`, and `expected`. CompareBeforeAfter will
then use these members to define a test method and test fixture.
`transform` may be one of the following.
- An instance of `tvm.ir.transform.Pass`
- A method that takes no arguments and returns a `tvm.ir.transform.Pass`
- A pytest fixture that returns a `tvm.ir.transform.Pass`
`before` may be any one of the following.
- An instance of `tvm.tir.PrimFunc`. This is allowed, but is not
the preferred method, as any errors in constructing the
`PrimFunc` occur while collecting the test, preventing any other
tests in the same file from being run.
- An TVMScript function, without the ``@T.prim_func`` decoration.
The ``@T.prim_func`` decoration will be applied when running the
test, rather than at module import.
- A method that takes no arguments and returns a `tvm.tir.PrimFunc`
- A pytest fixture that returns a `tvm.tir.PrimFunc`
`expected` may be any one of the following. The type of
`expected` defines the test being performed. If `expected`
provides a `tvm.tir.PrimFunc`, the result of the transformation
must match `expected`. If `expected` is an exception, then the
transformation must raise that exception type.
- Any option supported for `before`.
- The `Exception` class object, or a class object that inherits
from `Exception`.
- A method that takes no arguments and returns `Exception` or a
class object that inherits from `Exception`.
- A pytest fixture that returns `Exception` or an class object
that inherits from `Exception`.
Examples
--------
.. python::
class TestRemoveIf(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.Simplify()
def before(A: T.Buffer[1, "int32"]):
if True:
A[0] = 42
else:
A[0] = 5
def expected(A: T.Buffer[1, "int32"]):
A[0] = 42
"""

def __init_subclass__(cls):
if hasattr(cls, "before"):
cls.before = cls._normalize_before(cls.before)
if hasattr(cls, "expected"):
cls.expected = cls._normalize_expected(cls.expected)
if hasattr(cls, "transform"):
cls.transform = cls._normalize_transform(cls.transform)

@classmethod
def _normalize_before(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func

if isinstance(func, tvm.tir.PrimFunc):

def inner(self):
# pylint: disable=unused-argument
return func

elif cls._is_method(func):

def inner(self):
# pylint: disable=unused-argument
return func(self)

else:

def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code)

return pytest.fixture(inner)

@classmethod
def _normalize_expected(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func

if isinstance(func, tvm.tir.PrimFunc) or (
inspect.isclass(func) and issubclass(func, Exception)
):

def inner(self):
# pylint: disable=unused-argument
return func

elif cls._is_method(func):

def inner(self):
# pylint: disable=unused-argument
return func(self)

else:

def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code)

return pytest.fixture(inner)

@classmethod
def _normalize_transform(cls, transform):
if hasattr(transform, "_pytestfixturefunction"):
return transform

if isinstance(transform, tvm.ir.transform.Pass):

def inner(self):
# pylint: disable=unused-argument
return transform

elif cls._is_method(transform):

def inner(self):
# pylint: disable=unused-argument
return transform(self)

else:

raise TypeError(
"Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass"
)

return pytest.fixture(inner)

@staticmethod
def _is_method(func):
sig = inspect.signature(func)
return "self" in sig.parameters

def test_compare(self, before, expected, transform):
"""Unit test to compare the expected TIR PrimFunc to actual"""

before_mod = tvm.IRModule.from_expr(before)

if inspect.isclass(expected) and issubclass(expected, Exception):
with pytest.raises(expected):
after_mod = transform(before_mod)

# This portion through pytest.fail isn't strictly
# necessary, but gives a better error message that
# includes the before/after.
after = after_mod["main"]
script = tvm.IRModule({"after": after, "before": before}).script()
pytest.fail(
msg=(
f"Expected {expected.__name__} to be raised from transformation, "
f"instead received TIR\n:{script}"
)
)

elif isinstance(expected, tvm.tir.PrimFunc):
after_mod = transform(before_mod)
after = after_mod["main"]

try:
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
script = tvm.IRModule(
{"expected": expected, "after": after, "before": before}
).script()
raise ValueError(
f"TIR after transformation did not match expected:\n{script}"
) from err

else:
raise TypeError(
f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
f"to return either `Exception`, an `Exception` subclass, "
f"or an instance of `tvm.tir.PrimFunc`. "
f"Instead, received {type(exception)}."
)
38 changes: 2 additions & 36 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,31 +136,16 @@ def sls(n, d):
assert "if" not in str(stmt)


class BaseBeforeAfter:
def test_simplify(self):
before = self.before
before_mod = tvm.IRModule.from_expr(before)
after_mod = tvm.tir.transform.Simplify()(before_mod)
after = after_mod["main"]
expected = self.expected

try:
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
script = tvm.IRModule({"expected": expected, "after": after, "before": before}).script()
raise ValueError(
f"Function after simplification did not match expected:\n{script}"
) from err
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.Simplify()


class TestLoadStoreNoop(BaseBeforeAfter):
"""Store of a value that was just read from the same location is a no-op."""

@T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0]

@T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)

Expand All @@ -174,11 +159,9 @@ class TestLoadStoreNoopAfterSimplify(BaseBeforeAfter):
regression.
"""

@T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0] + (5.0 - 5.0)

@T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)

Expand All @@ -191,14 +174,12 @@ class TestNestedCondition(BaseBeforeAfter):
constraint.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
if i == 5:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
Expand All @@ -212,14 +193,12 @@ class TestNestedProvableCondition(BaseBeforeAfter):
conditional.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
if i < 7:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
Expand All @@ -233,14 +212,12 @@ class TestNestedVarCondition(BaseBeforeAfter):
constraint.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "float32"], n: T.int32):
for i in T.serial(16):
if i == n:
if i == n:
A[i] = 0.0

@T.prim_func
def expected(A: T.Buffer[(16,), "float32"], n: T.int32):
for i in T.serial(16):
if i == n:
Expand All @@ -256,7 +233,6 @@ class TestAlteredBufferContents(BaseBeforeAfter):
may not.
"""

@T.prim_func
def before(A: T.Buffer[(1,), "int32"], n: T.int32):
if A[0] == n:
A[0] = A[0] + 1
Expand All @@ -273,7 +249,6 @@ class TestNegationOfCondition(BaseBeforeAfter):
condition is known to be false.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i == 5:
Expand All @@ -282,7 +257,6 @@ def before(A: T.Buffer[(16,), "int32"]):
else:
A[i] = 1

@T.prim_func
def expected(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i == 5:
Expand All @@ -298,7 +272,6 @@ class TestNegationOfNotEqual(BaseBeforeAfter):
``i==5`` as the negation of a literal constraint.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i != 5:
Expand All @@ -307,7 +280,6 @@ def before(A: T.Buffer[(16,), "int32"]):
else:
A[i] = 1

@T.prim_func
def expected(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i != 5:
Expand All @@ -321,7 +293,6 @@ class TestNegationOfVarCondition(BaseBeforeAfter):
must rely on RewriteSimplifier recognizing the repeated literal.
"""

@T.prim_func
def before(A: T.Buffer[(16,), "int32"], n: T.int32):
for i in T.serial(16):
if i == n:
Expand All @@ -330,7 +301,6 @@ def before(A: T.Buffer[(16,), "int32"], n: T.int32):
else:
A[i] = 1

@T.prim_func
def expected(A: T.Buffer[(16,), "int32"], n: T.int32):
for i in T.serial(16):
if i == n:
Expand All @@ -346,14 +316,12 @@ class TestLiteralConstraintSplitBooleanAnd(BaseBeforeAfter):
the condition is to ensure we exercise RewriteSimplifier.
"""

@T.prim_func
def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n and j == n:
if i == n:
A[i, j] = 0

@T.prim_func
def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n and j == n:
Expand All @@ -371,7 +339,6 @@ class TestLiteralConstraintSplitBooleanOr(BaseBeforeAfter):
RewriteSimplifier.
"""

@T.prim_func
def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n or j == n:
Expand All @@ -382,7 +349,6 @@ def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
else:
A[i, j] = 2

@T.prim_func
def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n or j == n:
Expand Down
Loading

0 comments on commit 2b438ac

Please sign in to comment.