Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Node] Utility methods for ObjectPathPair handling #14498

Merged
merged 4 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,20 +195,40 @@ class SEqualReducer {
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const double& lhs, const double& rhs) const;
bool operator()(const int64_t& lhs, const int64_t& rhs) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
bool operator()(const int& lhs, const int& rhs) const;
bool operator()(const bool& lhs, const bool& rhs) const;
bool operator()(const std::string& lhs, const std::string& rhs) const;
bool operator()(const DataType& lhs, const DataType& rhs) const;
bool operator()(const double& lhs, const double& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const int64_t& lhs, const int64_t& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const std::string& lhs, const std::string& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;
bool operator()(const DataType& lhs, const DataType& rhs,
Optional<ObjectPathPair> paths = NullOpt) const;

template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
bool operator()(const ENum& lhs, const ENum& rhs,
Optional<ObjectPathPair> paths = NullOpt) const {
using Underlying = typename std::underlying_type<ENum>::type;
static_assert(std::is_same<Underlying, int>::value,
"Enum must have `int` as the underlying type");
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
}

template <typename T, typename Callable,
typename = std::enable_if_t<
std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
if (IsPathTracingEnabled()) {
ObjectPathPair current_paths = GetCurrentObjectPaths();
ObjectPathPair new_paths = {callable(current_paths->lhs_path),
callable(current_paths->rhs_path)};
return (*this)(lhs, rhs, new_paths);
} else {
return (*this)(lhs, rhs);
}
}

/*!
Expand Down Expand Up @@ -310,7 +330,8 @@ class SEqualReducer {
void RecordMismatchPaths(const ObjectPathPair& paths) const;

private:
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
Optional<ObjectPathPair> paths = NullOpt) const;

bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const ObjectPathPair* paths) const;
Expand All @@ -321,7 +342,8 @@ class SEqualReducer {

template <typename T>
static bool CompareAttributeValues(const T& lhs, const T& rhs,
const PathTracingData* tracing_data);
const PathTracingData* tracing_data,
Optional<ObjectPathPair> paths = NullOpt);

/*! \brief Internal class pointer. */
Handler* handler_ = nullptr;
Expand Down
64 changes: 32 additions & 32 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,46 +63,46 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
}

bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (!equal(this->attrs, other->attrs)) return false;
if (!equal(this->attrs, other->attrs, [](const auto& path) { return path->Attr("attrs"); })) {
return false;
}

if (equal.IsPathTracingEnabled()) {
if ((functions.size() != other->functions.size()) ||
(type_definitions.size() != other->type_definitions.size())) {
return false;
}
}

if (functions.size() != other->functions.size()) return false;
// Update GlobalVar remap
// Define remaps for GlobalVar and GlobalTypeVar based on their
// string name. Early bail-out is only performed when path-tracing
// is disabled, as the later equality checks on the member variables
// will provide better error messages.
for (const auto& gv : this->GetGlobalVars()) {
if (!other->ContainGlobalVar(gv->name_hint)) return false;
if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
if (other->ContainGlobalVar(gv->name_hint)) {
if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
} else if (!equal.IsPathTracingEnabled()) {
return false;
}
}
// Checking functions
for (const auto& kv : this->functions) {
if (equal.IsPathTracingEnabled()) {
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("functions")
->MapValue(other->GetGlobalVar(kv.first->name_hint))};
if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false;
} else {
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
for (const auto& gtv : this->GetGlobalTypeVars()) {
if (other->ContainGlobalTypeVar(gtv->name_hint)) {
if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false;
} else if (!equal.IsPathTracingEnabled()) {
return false;
}
}

if (type_definitions.size() != other->type_definitions.size()) return false;
// Update GlobalTypeVar remap
for (const auto& gtv : this->GetGlobalTypeVars()) {
if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false;
if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false;
// Checking functions and type definitions
if (!equal(this->functions, other->functions,
[](const auto& path) { return path->Attr("functions"); })) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is func_path (adding function name to the path) handled like before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, yes. The IRModule's handler only needs to handle appending .functions to the path, with the function name added by the equality check on Map. This results in a path of <root>.functions[I.GlobalVar("func")].body....

I've added a unit test that verifies that the error message includes the function name, as that's definitely a good thing to verify for user-friendliness.

return false;
}
// Checking type_definitions
for (const auto& kv : this->type_definitions) {
if (equal.IsPathTracingEnabled()) {
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
ObjectPathPair type_paths = {
obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("type_definitions")
->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))};
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false;
} else {
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
}
if (!equal(this->type_definitions, other->type_definitions,
[](const auto& path) { return path->Attr("type_definitions"); })) {
return false;
}

return true;
}

Expand Down
51 changes: 36 additions & 15 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,51 +109,72 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {

template <typename T>
/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs,
const PathTracingData* tracing_data) {
const PathTracingData* tracing_data,
Optional<ObjectPathPair> paths) {
if (BaseValueEqual()(lhs, rhs)) {
return true;
} else {
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
return false;
}

if (tracing_data && !tracing_data->first_mismatch->defined()) {
if (paths) {
*tracing_data->first_mismatch = paths.value();
} else {
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
}
}
return false;
}

bool SEqualReducer::operator()(const double& lhs, const double& rhs) const {
bool SEqualReducer::operator()(const double& lhs, const double& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const {
bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const {
bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const int& lhs, const int& rhs) const {
bool SEqualReducer::operator()(const int& lhs, const int& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const {
bool SEqualReducer::operator()(const bool& lhs, const bool& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const {
bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const {
bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs,
Optional<ObjectPathPair> paths) const {
return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address,
const void* rhs_address) const {
const void* rhs_address, Optional<ObjectPathPair> paths) const {
if (lhs == rhs) {
return true;
} else {
GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_);
return false;
}

if (tracing_data_ && !tracing_data_->first_mismatch->defined()) {
if (paths) {
*tracing_data_->first_mismatch = paths.value();
} else {
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data_);
}
}

return false;
}

const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const {
Expand Down