From 49a0daec995fbd308e3f757f3cd05552df3f7471 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 3 Apr 2023 12:35:20 -0500 Subject: [PATCH 1/4] [Node] Utility methods for ObjectPathPair handling This commit adds a templated overload to `SEqualReducer::operator()` that accepts a lambda function to update the path of the LHS and RHS of the comparison. ```c++ // Usage prior to this utility function if (equal.IsPathTracingEnabled()) { const ObjectPathPair& self_paths = equal.GetCurrentObjectPaths(); ObjectPathPair attr_paths = {self_paths->lhs_path->Attr("value"), self_paths->rhs_path->Attr("value")}; if (!equal(this->value, other->value, attr_paths)) return false; } else { if (!equal(this->value, other->value)) return false; } // Usage after this utility function if (!equal(this->value, other->value, [](const auto& path) { return path->Attr("value"); })) { return false; } ``` --- include/tvm/node/structural_equal.h | 44 +++++++++++++++----- src/ir/module.cc | 64 ++++++++++++++--------------- src/node/structural_equal.cc | 51 ++++++++++++++++------- 3 files changed, 101 insertions(+), 58 deletions(-) diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 5bd76404a998..cff1e775072f 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -195,20 +195,40 @@ class SEqualReducer { * \param rhs The right operand. * \return the immediate check result. */ - bool operator()(const double& lhs, const double& rhs) const; - bool operator()(const int64_t& lhs, const int64_t& rhs) const; - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const; - bool operator()(const int& lhs, const int& rhs) const; - bool operator()(const bool& lhs, const bool& rhs) const; - bool operator()(const std::string& lhs, const std::string& rhs) const; - bool operator()(const DataType& lhs, const DataType& rhs) const; + bool operator()(const double& lhs, const double& rhs, + Optional paths = NullOpt) const; + bool operator()(const int64_t& lhs, const int64_t& rhs, + Optional paths = NullOpt) const; + bool operator()(const uint64_t& lhs, const uint64_t& rhs, + Optional paths = NullOpt) const; + bool operator()(const int& lhs, const int& rhs, Optional paths = NullOpt) const; + bool operator()(const bool& lhs, const bool& rhs, Optional paths = NullOpt) const; + bool operator()(const std::string& lhs, const std::string& rhs, + Optional paths = NullOpt) const; + bool operator()(const DataType& lhs, const DataType& rhs, + Optional paths = NullOpt) const; template ::value>::type> - bool operator()(const ENum& lhs, const ENum& rhs) const { + bool operator()(const ENum& lhs, const ENum& rhs, + Optional paths = NullOpt) const { using Underlying = typename std::underlying_type::type; static_assert(std::is_same::value, "Enum must have `int` as the underlying type"); - return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs); + return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs, paths); + } + + template , ObjectPath>>> + bool operator()(const T& lhs, const T& rhs, const Callable& callable) { + if (IsPathTracingEnabled()) { + ObjectPathPair current_paths = GetCurrentObjectPaths(); + ObjectPathPair new_paths = {callable(current_paths->lhs_path), + callable(current_paths->rhs_path)}; + return (*this)(lhs, rhs, new_paths); + } else { + return (*this)(lhs, rhs); + } } /*! @@ -310,7 +330,8 @@ class SEqualReducer { void RecordMismatchPaths(const ObjectPathPair& paths) const; private: - bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const; + bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address, + Optional paths = NullOpt) const; bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const ObjectPathPair* paths) const; @@ -321,7 +342,8 @@ class SEqualReducer { template static bool CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data); + const PathTracingData* tracing_data, + Optional paths = NullOpt); /*! \brief Internal class pointer. */ Handler* handler_ = nullptr; diff --git a/src/ir/module.cc b/src/ir/module.cc index 7a973da29dfa..4d5bebf70894 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -63,46 +63,46 @@ IRModule::IRModule(tvm::Map functions, } bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (!equal(this->attrs, other->attrs)) return false; + if (!equal(this->attrs, other->attrs, [](const auto& path) { return path->Attr("attrs"); })) { + return false; + } + + if (equal.IsPathTracingEnabled()) { + if ((functions.size() != other->functions.size()) || + (type_definitions.size() != other->type_definitions.size())) { + return false; + } + } - if (functions.size() != other->functions.size()) return false; - // Update GlobalVar remap + // Define remaps for GlobalVar and GlobalTypeVar based on their + // string name. Early bail-out is only performed when path-tracing + // is disabled, as the later equality checks on the member variables + // will provide better error messages. for (const auto& gv : this->GetGlobalVars()) { - if (!other->ContainGlobalVar(gv->name_hint)) return false; - if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + if (other->ContainGlobalVar(gv->name_hint)) { + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } else if (!equal.IsPathTracingEnabled()) { + return false; + } } - // Checking functions - for (const auto& kv : this->functions) { - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first), - obj_path_pair->rhs_path->Attr("functions") - ->MapValue(other->GetGlobalVar(kv.first->name_hint))}; - if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false; - } else { - if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (other->ContainGlobalTypeVar(gtv->name_hint)) { + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } else if (!equal.IsPathTracingEnabled()) { + return false; } } - if (type_definitions.size() != other->type_definitions.size()) return false; - // Update GlobalTypeVar remap - for (const auto& gtv : this->GetGlobalTypeVars()) { - if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; - if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + // Checking functions and type definitions + if (!equal(this->functions, other->functions, + [](const auto& path) { return path->Attr("functions"); })) { + return false; } - // Checking type_definitions - for (const auto& kv : this->type_definitions) { - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - ObjectPathPair type_paths = { - obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first), - obj_path_pair->rhs_path->Attr("type_definitions") - ->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))}; - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false; - } else { - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; - } + if (!equal(this->type_definitions, other->type_definitions, + [](const auto& path) { return path->Attr("type_definitions"); })) { + return false; } + return true; } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 42726af9859a..66a347f6b8ba 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -109,51 +109,72 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { template /* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs, - const PathTracingData* tracing_data) { + const PathTracingData* tracing_data, + Optional paths) { if (BaseValueEqual()(lhs, rhs)) { return true; - } else { - GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); - return false; } + + if (tracing_data && !tracing_data->first_mismatch->defined()) { + if (paths) { + *tracing_data->first_mismatch = paths.value(); + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); + } + } + return false; } -bool SEqualReducer::operator()(const double& lhs, const double& rhs) const { +bool SEqualReducer::operator()(const double& lhs, const double& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const { +bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const { +bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const int& lhs, const int& rhs) const { +bool SEqualReducer::operator()(const int& lhs, const int& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const { +bool SEqualReducer::operator()(const bool& lhs, const bool& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const { +bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } -bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const { +bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs, + Optional paths) const { return CompareAttributeValues(lhs, rhs, tracing_data_); } bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, - const void* rhs_address) const { + const void* rhs_address, Optional paths) const { if (lhs == rhs) { return true; - } else { - GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_); - return false; } + + if (tracing_data_ && !tracing_data_->first_mismatch->defined()) { + if (paths) { + *tracing_data_->first_mismatch = paths.value(); + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data_); + } + } + + return false; } const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const { From 83eaea0b373d0447f88e9535dfd5a9b270747d6a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 4 Apr 2023 12:27:04 -0500 Subject: [PATCH 2/4] Unit test testing that discrepant path includes the PrimFunc's name --- .../test_tir_structural_equal_hash.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index 4bb13ed77ad8..103d89c908a9 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -19,6 +19,7 @@ import pytest from tvm import te from tvm.runtime import ObjectPath +from tvm.script import tir as T, ir as I def consistent_equal(x, y, map_free_vars=False): @@ -394,13 +395,29 @@ def test_seq_length_mismatch(): assert rhs_path == expected_rhs_path +def test_ir_module_equal(): + def generate(n: int): + @I.ir_module + class module: + @T.prim_func + def func(A: T.Buffer(1, "int32")): + for i in range(n): + A[0] = A[0] + 1 + + return module + + # Equivalent IRModules should compare as equivalent, even though + # they have distinct GlobalVars, and GlobalVars usually compare by + # reference equality. + tvm.ir.assert_structural_equal(generate(16), generate(16)) + + # When there is a difference, the location should include the + # function name that caused the failure. + with pytest.raises(ValueError) as err: + tvm.ir.assert_structural_equal(generate(16), generate(32)) + + assert '.functions[I.GlobalVar("func")].body.extent.value' in err.value + + if __name__ == "__main__": - test_exprs() - test_prim_func() - test_attrs() - test_array() - test_env_func() - test_stmt() - test_buffer_storage_scope() - test_buffer_load_store() - test_while() + tvm.testing.main() From fa0236c793fc57127d07f3b65de6c6dae406cf9e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 4 Apr 2023 12:49:07 -0500 Subject: [PATCH 3/4] Updated docstring to resolve linting error --- include/tvm/node/structural_equal.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index cff1e775072f..acc362758a7c 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -191,8 +191,17 @@ class SEqualReducer { /*! * \brief Reduce condition to comparison of two attribute values. + * * \param lhs The left operand. + * * \param rhs The right operand. + * + * \param paths The paths to the LHS and RHS operands. If + * unspecified, will attempt to identify the attribute's address + * within the most recent ObjectRef. In general, the paths only + * require explicit handling for computed parameters + * (e.g. `array.size()`) + * * \return the immediate check result. */ bool operator()(const double& lhs, const double& rhs, From 262ddbaa793bb0b1708b4a2b90f851715dc49ec8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 5 Apr 2023 09:49:16 -0500 Subject: [PATCH 4/4] Fixed where to look for error message in unit test --- tests/python/unittest/test_tir_structural_equal_hash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index 103d89c908a9..eca78d649b85 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -416,7 +416,7 @@ def func(A: T.Buffer(1, "int32")): with pytest.raises(ValueError) as err: tvm.ir.assert_structural_equal(generate(16), generate(32)) - assert '.functions[I.GlobalVar("func")].body.extent.value' in err.value + assert '.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0] if __name__ == "__main__":