Skip to content

Commit

Permalink
Add check to ensure ANF runs in AOT
Browse files Browse the repository at this point in the history
Change-Id: I8de2bd19c7c17057e2bc89f6a68595780c2e9433
  • Loading branch information
lhutton1 committed May 5, 2022
1 parent 7d2fd9f commit 02279c2
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from tvm.relay.backend import Executor, Runtime
from tvm.micro import model_library_format as mlf
from tvm.micro import export_model_library_format
from tvm.ir.instrument import pass_instrument
from aot_test_utils import (
AOTTestModel,
AOT_DEFAULT_RUNNER,
Expand Down Expand Up @@ -1027,5 +1028,51 @@ def test_aot_codegen_checks_returns():
)


def test_aot_uses_anf():
"""Checks that A-Normal Form is being used in the AOT lowering pipeline."""
x = relay.var("x", shape=(1, 10, 10, 10))
y = relay.var("y", shape=(1, 10, 10, 10))
z = relay.add(x, y)
func = relay.Function([x, y], z)

@pass_instrument
class CheckANFRuns:
def __init__(self):
self.did_run_anf = False

def run_before_pass(self, _, info):
if info.name == "ToANormalForm":
self.did_run_anf = True
if info.name == "LowerTE":
assert self.did_run_anf, "ToANormalForm pass should run before LowerTE."

check_run_anf = CheckANFRuns()

model = AOTTestModel(module=IRModule.from_expr(func), inputs=None, outputs=None)
runtime = Runtime("crt")
executor = Executor(
"aot",
{
"workspace-byte-alignment": 8,
"interface-api": "c",
"unpacked-api": True,
},
)
config = {"tir.disable_vectorize": True}

with tvm.transform.PassContext(opt_level=3, config=config, instruments=[check_run_anf]):
tvm.relay.build(
model.module,
tvm.target.Target("c"),
executor=executor,
runtime=runtime,
workspace_memory_pools=None,
params=model.params,
mod_name=model.name,
)

assert check_run_anf.did_run_anf, "Expected ToANormalForm pass to have run."


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 02279c2

Please sign in to comment.