Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Node] Utility methods for ObjectPathPair handling #14498

Merged
merged 4 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,24 +191,53 @@ class SEqualReducer {

/*!
* \brief Reduce condition to comparison of two attribute values.
*
* \param lhs The left operand.
*
* \param rhs The right operand.
*
* \param paths The paths to the LHS and RHS operands. If
* unspecified, will attempt to identify the attribute's address
* within the most recent ObjectRef. In general, the paths only
* require explicit handling for computed parameters
* (e.g. `array.size()`)
*
* \return the immediate check result.
*/
bool operator()(const double& lhs, const double& rhs) const;
bool operator()(const int64_t& lhs, const int64_t& rhs) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
bool operator()(const int& lhs, const int& rhs) const;
bool operator()(const bool& lhs, const bool& rhs) const;
bool operator()(const std::string& lhs, const std::string& rhs) const;
bool operator()(const DataType& lhs, const DataType& rhs) const;
bool operator()(const double& lhs, const double& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const int64_t& lhs, const int64_t& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const std::string& lhs, const std::string& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const DataType& lhs, const DataType& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;

template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
bool operator()(const ENum& lhs, const ENum& rhs,
Optional<ObjectPathPair> paths = NullOpt) const {
using Underlying = typename std::underlying_type<ENum>::type;
static_assert(std::is_same<Underlying, int>::value,
"Enum must have `int` as the underlying type");
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
}

template <typename T, typename Callable,
typename = std::enable_if_t<
std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
if (IsPathTracingEnabled()) {
ObjectPathPair current_paths = GetCurrentObjectPaths();
ObjectPathPair new_paths = {callable(current_paths->lhs_path),
callable(current_paths->rhs_path)};
return (*this)(lhs, rhs, new_paths);
} else {
return (*this)(lhs, rhs);
}
}

/*!
Expand Down Expand Up @@ -310,7 +339,8 @@ class SEqualReducer {
void RecordMismatchPaths(const ObjectPathPair& paths) const;

private:
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
Optional<ObjectPathPair> paths = NullOpt) const;

bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const ObjectPathPair* paths) const;
Expand All @@ -321,7 +351,8 @@ class SEqualReducer {

template <typename T>
static bool CompareAttributeValues(const T& lhs, const T& rhs,
const PathTracingData* tracing_data);
const PathTracingData* tracing_data,
Optional<ObjectPathPair> paths = NullOpt);

/*! \brief Internal class pointer. */
Handler* handler_ = nullptr;
Expand Down
64 changes: 32 additions & 32 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,46 +63,46 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
}

bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (!equal(this->attrs, other->attrs)) return false;
if (!equal(this->attrs, other->attrs, [](const auto& path) { return path->Attr("attrs"); })) {
return false;
}

if (equal.IsPathTracingEnabled()) {
if ((functions.size() != other->functions.size()) ||
(type_definitions.size() != other->type_definitions.size())) {
return false;
}
}

if (functions.size() != other->functions.size()) return false;
// Update GlobalVar remap
// Define remaps for GlobalVar and GlobalTypeVar based on their
// string name. Early bail-out is only performed when path-tracing
// is disabled, as the later equality checks on the member variables
// will provide better error messages.
for (const auto& gv : this->GetGlobalVars()) {
if (!other->ContainGlobalVar(gv->name_hint)) return false;
if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
if (other->ContainGlobalVar(gv->name_hint)) {
if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
} else if (!equal.IsPathTracingEnabled()) {
return false;
}
}
// Checking functions
for (const auto& kv : this->functions) {
if (equal.IsPathTracingEnabled()) {
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("functions")
->MapValue(other->GetGlobalVar(kv.first->name_hint))};
if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false;
} else {
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
for (const auto& gtv : this->GetGlobalTypeVars()) {
if (other->ContainGlobalTypeVar(gtv->name_hint)) {
if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false;
} else if (!equal.IsPathTracingEnabled()) {
return false;
}
}

if (type_definitions.size() != other->type_definitions.size()) return false;
// Update GlobalTypeVar remap
for (const auto& gtv : this->GetGlobalTypeVars()) {
if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false;
if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false;
// Checking functions and type definitions
if (!equal(this->functions, other->functions,
[](const auto& path) { return path->Attr("functions"); })) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is func_path (adding function name to the path) handled like before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, yes. The IRModule's handler only needs to handle appending .functions to the path, with the function name added by the equality check on Map. This results in a path of <root>.functions[I.GlobalVar("func")].body....

I've added a unit test that verifies that the error message includes the function name, as that's definitely a good thing to verify for user-friendliness.

return false;
}
// Checking type_definitions
for (const auto& kv : this->type_definitions) {
if (equal.IsPathTracingEnabled()) {
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
ObjectPathPair type_paths = {
obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("type_definitions")
->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))};
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false;
} else {
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
}
if (!equal(this->type_definitions, other->type_definitions,
[](const auto& path) { return path->Attr("type_definitions"); })) {
return false;
}

return true;
}

Expand Down
51 changes: 36 additions & 15 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,51 +109,72 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {

template <typename T>
/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs,
const PathTracingData* tracing_data) {
const PathTracingData* tracing_data,
Optional<ObjectPathPair> paths) {
if (BaseValueEqual()(lhs, rhs)) {
return true;
} else {
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
return false;
}

if (tracing_data && !tracing_data->first_mismatch->defined()) {
if (paths) {
*tracing_data->first_mismatch = paths.value();
} else {
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
}
}
return false;
}

bool SEqualReducer::operator()(const double& lhs, const double& rhs) const {
bool SEqualReducer::operator()(const double& lhs, const double& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const {
bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const {
bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const int& lhs, const int& rhs) const {
bool SEqualReducer::operator()(const int& lhs, const int& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const {
bool SEqualReducer::operator()(const bool& lhs, const bool& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const {
bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const {
bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address,
const void* rhs_address) const {
const void* rhs_address, Optional<ObjectPathPair> paths) const {
if (lhs == rhs) {
return true;
} else {
GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_);
return false;
}

if (tracing_data_ && !tracing_data_->first_mismatch->defined()) {
if (paths) {
*tracing_data_->first_mismatch = paths.value();
} else {
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data_);
}
}

return false;
}

const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const {
Expand Down
35 changes: 26 additions & 9 deletions tests/python/unittest/test_tir_structural_equal_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest
from tvm import te
from tvm.runtime import ObjectPath
from tvm.script import tir as T, ir as I


def consistent_equal(x, y, map_free_vars=False):
Expand Down Expand Up @@ -394,13 +395,29 @@ def test_seq_length_mismatch():
assert rhs_path == expected_rhs_path


def test_ir_module_equal():
def generate(n: int):
@I.ir_module
class module:
@T.prim_func
def func(A: T.Buffer(1, "int32")):
for i in range(n):
A[0] = A[0] + 1

return module

# Equivalent IRModules should compare as equivalent, even though
# they have distinct GlobalVars, and GlobalVars usually compare by
# reference equality.
tvm.ir.assert_structural_equal(generate(16), generate(16))

# When there is a difference, the location should include the
# function name that caused the failure.
with pytest.raises(ValueError) as err:
tvm.ir.assert_structural_equal(generate(16), generate(32))

assert '<root>.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0]


if __name__ == "__main__":
test_exprs()
test_prim_func()
test_attrs()
test_array()
test_env_func()
test_stmt()
test_buffer_storage_scope()
test_buffer_load_store()
test_while()
tvm.testing.main()