diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 3c44d2bf1bc8..2991cc01fc92 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -36,6 +36,7 @@ from tvm.relay.backend import Executor, Runtime from tvm.micro import model_library_format as mlf from tvm.micro import export_model_library_format +from tvm.ir.instrument import pass_instrument from aot_test_utils import ( AOTTestModel, AOT_DEFAULT_RUNNER, @@ -1027,5 +1028,51 @@ def test_aot_codegen_checks_returns(): ) +def test_aot_uses_anf(): + """Checks that A-Normal Form is being used in the AOT lowering pipeline.""" + x = relay.var("x", shape=(1, 10, 10, 10)) + y = relay.var("y", shape=(1, 10, 10, 10)) + z = relay.add(x, y) + func = relay.Function([x, y], z) + + @pass_instrument + class CheckANFRuns: + def __init__(self): + self.did_run_anf = False + + def run_before_pass(self, _, info): + if info.name == "ToANormalForm": + self.did_run_anf = True + if info.name == "LowerTE": + assert self.did_run_anf, "ToANormalForm pass should run before LowerTE." + + check_run_anf = CheckANFRuns() + + model = AOTTestModel(module=IRModule.from_expr(func), inputs=None, outputs=None) + runtime = Runtime("crt") + executor = Executor( + "aot", + { + "workspace-byte-alignment": 8, + "interface-api": "c", + "unpacked-api": True, + }, + ) + config = {"tir.disable_vectorize": True} + + with tvm.transform.PassContext(opt_level=3, config=config, instruments=[check_run_anf]): + tvm.relay.build( + model.module, + tvm.target.Target("c"), + executor=executor, + runtime=runtime, + workspace_memory_pools=None, + params=model.params, + mod_name=model.name, + ) + + assert check_run_anf.did_run_anf, "Expected ToANormalForm pass to have run." + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))