Skip to content

Commit

Permalink
[TIR] Allow VerifyWellFormed to accept IRModule (#15247)
Browse files Browse the repository at this point in the history
Previously, the calling code needed to iterate over all functions in a
module.  This commit adds an overload that accepts `const IRModule&`,
allowing it to be called more easily.  This also provides an API that
can be extended to validate behavior across an entire
IRModule (e.g. requiring that internal function calls have the correct
argument types).
  • Loading branch information
Lunderberg authored Jul 6, 2023
1 parent 88701dc commit 2f7c097
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
8 changes: 4 additions & 4 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,14 @@ def apply_prim_func_arg_and_result_memory_constraints(
)


def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
def verify_well_formed(obj: Union[PrimFunc, IRModule], assert_mode: bool = True) -> bool:
"""Verify if the given TIR is well-formed. The verification includes:
- Check if expressions not contain vars that is defined outside the block.
Parameters
----------
func: tvm.tir.PrimFunc
The function to be verified.
obj: Union[tvm.tir.PrimFunc, tvm.ir.IRModule]
The function or module to be verified.
assert_mode: bool
The indicator if it raises an error when the function is not well-formed.
Expand All @@ -366,7 +366,7 @@ def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
result: bool
Whether it is a well-formed TIR function.
"""
return _ffi_api.VerifyWellFormed(func, assert_mode) # type: ignore # pylint: disable=no-member
return _ffi_api.VerifyWellFormed(obj, assert_mode) # type: ignore # pylint: disable=no-member


def OOBChecker():
Expand Down
25 changes: 24 additions & 1 deletion src/tir/analysis/verify_well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/stmt_functor.h>

#include "../ir/functor_common.h"
#include "tvm/ir/module.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -142,7 +143,29 @@ bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
return true;
}

TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed);
bool VerifyWellFormed(const IRModule& mod, bool assert_mode) {
for (const auto& [gvar, base_func] : mod->functions) {
if (auto prim_func = base_func.as<PrimFunc>()) {
bool res = VerifyWellFormed(prim_func.value(), assert_mode);
if (!res) {
return false;
}
}
}
return true;
}

TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed")
.set_body_typed([](const ObjectRef& obj, bool assert_mode) {
if (auto opt = obj.as<PrimFunc>()) {
return VerifyWellFormed(opt.value(), assert_mode);
} else if (auto opt = obj.as<IRModule>()) {
return VerifyWellFormed(opt.value(), assert_mode);
} else {
LOG(FATAL) << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found "
<< obj->GetTypeKey();
}
});

} // namespace tir
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def element_wise(
C[i, j] = B[i, j] * 2.0

assert tvm.tir.analysis.verify_well_formed(element_wise)
assert tvm.tir.analysis.verify_well_formed(tvm.IRModule.from_expr(element_wise))


def test_fail_use_out_loop_var():
Expand Down

0 comments on commit 2f7c097

Please sign in to comment.