Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft][TVMScript][Unittest] Validate round-trip for each TIR lowering step #14486

Closed
wants to merge 13 commits into from

Conversation

Lunderberg
Copy link
Contributor

Prior to this PR, some of the IRModule transformations used during lowering can produce TIR that cannot be round-tripped through TVMScript. Since TVMScript is the default method for printing all TIR, this can make it difficult to identify which pass has introduced a breaking change.

The first commit in this PR adds a test that checks whether a module can correctly round-trip from TIR to TVMScript and back, for each lowering pass used in tvm.build. Each of the following independent commits resolves one of the issues found by the more general test, and adds a specific test for the round-trip failure that it resolves.

My plan is to break out these independent commits will be into separate PRs, for ease of review. After all the component PRs have landed or found an alternative fix, this PR will be rebased to only consist of the new test itself, and will be opened for review.

Lunderberg and others added 13 commits April 4, 2023 10:21
Prior to this PR, some of the IRModule transformations used during
lowering can produce TIR that cannot be round-tripped through
TVMScript.  Since TVMScript is the default method for printing all
TIR, this can make it difficult to identify which pass has introduced
a breaking change.

This PR adds a test that checks whether a module can correctly
round-trip from TIR to TVMScript and back, for each lowering pass used
in `tvm.build`.
This PR adds the TVMScript parser/ir_builder support based on the
blockbuilder.  This commit contains the non-relax portions from
apache#13932.

Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Tianqi Chen <tianqi.tchen@gmail.com>
Co-authored-by: Yuchen Jin <yuchenj@cs.washington.edu>
Co-authored-by: Steven S. Lyubomirsky <slyubomirsky@gmail.com>
Co-authored-by: Yong Wu <yongcale@gmail.com>
This is an upstreaming of the non-relax portions of
apache#14132, including a unit test
specically to validate `I.module_attrs`.
Prior to this commit, the python API `tvm.tir.op.tvm_struct_set`
defined the return type of `builtin::tvm_struct_set` as `"handle"`,
while the C++ API `tvm::tir::TVMStructSet` defined the return type as
`DataType::Int(32)`.  The data type used for this builtin has no
effect, because no value is returned.  However, this discrepancy can
cause failure to roundtrip through TVMScript.

This commit updates the Python API to use `"int32"`, for consistency
with the C++ API and with `CodeGenCPU`.
Prior to this commit, `SeqStmt::Flatten` could accept an arbitrary
number of arguments, where each argument was of type `const tir::Stmt&`
or an iterable.  However, if `SeqStmt::Flatten` were passed a subclass
of `tir::Stmt`, the templated overload was selected as the better
match.

This commit rewrites `SeqStmt::Flatten` using C++17's `"constexpr if"`
feature, to handle cases of `SeqStmt`, superclasses of `SeqStmt`, and
other subclasses of `Stmt`.
`tir::StringImm` can round-trip through TVMScript when used in a
context that requires a PrimExpr, such as the arguments of a
`tir::Call`.  However, contexts that only require a `ObjectRef`, such
as the `AttrStmtNode::node`, use the same TVMScript representation as
`"string_value"`, but are parsed `tvm::String` instances.

This commit updates `MakePackedAPI` to use `String` instead of
`StringImm` in its default value for `AttrStmtNode::node`.
Previously, SeqStmt could be nested, making a distinction between the
nested `SeqStmt({SeqStmt({a,b}), c})` and the flat `SeqStmt({a,b,c})`,
even though the two are semantically equivalent.  This also caused an
issue with round-trips through TVMScript, which does not preserve this
distinction.

This commit updates the `SeqStmt` constructor and the `SeqStmt`
visitor in `StmtMutator` to flatten nested sequential statements
provided.
This commit adds a templated overload to `SEqualReducer::operator()`
that accepts a lambda function to update the path of the LHS and RHS
of the comparison.

```c++
// Usage prior to this utility function
if (equal.IsPathTracingEnabled()) {
  const ObjectPathPair& self_paths = equal.GetCurrentObjectPaths();
  ObjectPathPair attr_paths = {self_paths->lhs_path->Attr("value"),
                               self_paths->rhs_path->Attr("value")};
  if (!equal(kv.second, other->Lookup(kv.first->name_hint), attr_paths)) return false;
} else {
  if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
}

// Usage after this utility function
if (!equal(kv.second, other->Lookup(kv.first->name_hint),
           [](const auto& path) { return path->Attr("value"); })) {
  return false;
}
```
Prior to this commit, `PrimFuncPass` directly removed empty `PrimFunc`
objects from the module's `Map<GlobalVar, BaseFunc> functions`.
Because it didn't update the `global_var_map_` as well, these two maps
could become out of sync.  Since the `global_var_map_` is checked as
part of `StructuralEqual()`, but isn't displayed when printing to
TVMScript, this can result in identical printouts being flagged as
non-identical.

This commit updates `PrimFuncPass` to call the `IRModuleNode::Remove`
method, which updates both the `functions` and `global_var_map_`
variables.
Previously, `kDeviceThreadAxis` defined the IterVar to be used for
each thread/block axis, and `kUseDynamicSharedMemoryTag` defined
whether dynamic memory allocations exist, which are primarily used to
produce a list of strings by `tvm::codegen::ExtractFuncInfo`.  Because
`kDeviceThreadAxis` is a `Array<IterVar>`, the IterVar is used prior
to its definition site at `tir::attr::thread_extent`, which results in
errors when attempting to round-trip through TVMScript.

This commit replaces these attributes with
`attr::kKernelLaunchParams`, which directly contains the kernel launch
parameters.  These are expressed as an `Array<String>`, allowing the
generated TVMScript to successfully round-trip.
When passes create new PrimFuncs, such as when `tir.SplitHostDevice`
separates out a `tir::Stmt` into an independent function, the
parameters of these new function may alias existing variable
definitions.  While this is well-defined, because variable definitions
are not shared across function boundaries, it can give false
discrepancies from `tvm.ir.assert_structural_equal`.

This commit implements `tvm::tir::transform::ConvertSSA`, which
ensures unique variable declaration locations across an entire module.
Avoid duplicate variable defitions between the host and device
PrimFunc.
@tvm-bot
Copy link
Collaborator

tvm-bot commented Apr 4, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@Lunderberg
Copy link
Contributor Author

Closing this PR, as it is no longer relevant to the majority of TVM development.

@Lunderberg Lunderberg closed this Sep 11, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants