Skip to content

Commit

Permalink
Address Lite's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gbonik committed Jun 27, 2022
1 parent fca83c0 commit 37183a2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 33 deletions.
16 changes: 8 additions & 8 deletions include/tvm/node/object_path.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class ObjectPathNode : public Object {
const ObjectPathNode* ParentNode() const;

/*! Compares just the last node of the path, without comparing the whole path. */
virtual bool LastNodeEqual(const ObjectPathNode& other) const = 0;
virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0;

virtual std::string LastNodeString() const = 0;

Expand Down Expand Up @@ -130,7 +130,7 @@ class RootPathNode final : public ObjectPathNode {
TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode& other) const final;
bool LastNodeEqual(const ObjectPathNode* other) const final;
std::string LastNodeString() const final;
};

Expand All @@ -152,7 +152,7 @@ class AttributeAccessPathNode final : public ObjectPathNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode& other) const final;
bool LastNodeEqual(const ObjectPathNode* other) const final;
std::string LastNodeString() const final;
};

Expand All @@ -171,7 +171,7 @@ class UnknownAttributeAccessPathNode final : public ObjectPathNode {
TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode& other) const final;
bool LastNodeEqual(const ObjectPathNode* other) const final;
std::string LastNodeString() const final;
};

Expand All @@ -194,7 +194,7 @@ class ArrayIndexPathNode : public ObjectPathNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode& other) const final;
bool LastNodeEqual(const ObjectPathNode* other) const final;
std::string LastNodeString() const final;
};

Expand All @@ -216,7 +216,7 @@ class MissingArrayElementPathNode : public ObjectPathNode {
TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode& other) const final;
bool LastNodeEqual(const ObjectPathNode* other) const final;
std::string LastNodeString() const final;
};

Expand All @@ -238,7 +238,7 @@ class MapValuePathNode : public ObjectPathNode {
TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode& other) const final;
bool LastNodeEqual(const ObjectPathNode* other) const final;
std::string LastNodeString() const final;
};

Expand All @@ -257,7 +257,7 @@ class MissingMapEntryPathNode : public ObjectPathNode {
TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode& other) const final;
bool LastNodeEqual(const ObjectPathNode* other) const final;
std::string LastNodeString() const final;
};

Expand Down
32 changes: 16 additions & 16 deletions src/node/object_path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ bool ObjectPath::PathsEqual(const ObjectPath& other) const {
if (lhs->type_index() != rhs->type_index()) {
return false;
}
if (!lhs->LastNodeEqual(*rhs)) {
if (!lhs->LastNodeEqual(rhs)) {
return false;
}
lhs = lhs->ParentNode();
Expand Down Expand Up @@ -171,7 +171,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<ObjectPathNode>(PrintObj

RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {}

bool RootPathNode::LastNodeEqual(const ObjectPathNode& other) const { return true; }
bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; }

std::string RootPathNode::LastNodeString() const { return "<root>"; }

Expand All @@ -182,9 +182,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<RootPathNode>(PrintObjec
AttributeAccessPathNode::AttributeAccessPathNode(ObjectPathNode* parent, String attr_key)
: ObjectPathNode(parent), attr_key(std::move(attr_key)) {}

bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode& other) const {
const auto& otherAttrAccess = static_cast<const AttributeAccessPathNode&>(other);
return attr_key == otherAttrAccess.attr_key;
bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const {
const auto* otherAttrAccess = static_cast<const AttributeAccessPathNode*>(other);
return attr_key == otherAttrAccess->attr_key;
}

std::string AttributeAccessPathNode::LastNodeString() const { return "." + attr_key; }
Expand All @@ -197,7 +197,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(ObjectPathNode* parent)
: ObjectPathNode(parent) {}

bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode& other) const {
bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const {
// Consider any two unknown attribute accesses unequal
return false;
}
Expand All @@ -214,9 +214,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
ArrayIndexPathNode::ArrayIndexPathNode(ObjectPathNode* parent, size_t index)
: ObjectPathNode(parent), index(index) {}

bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode& other) const {
const auto& otherArrayIndex = static_cast<const ArrayIndexPathNode&>(other);
return index == otherArrayIndex.index;
bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const {
const auto* otherArrayIndex = static_cast<const ArrayIndexPathNode*>(other);
return index == otherArrayIndex->index;
}

std::string ArrayIndexPathNode::LastNodeString() const { return "[" + std::to_string(index) + "]"; }
Expand All @@ -228,9 +228,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<ArrayIndexPathNode>(Prin
MissingArrayElementPathNode::MissingArrayElementPathNode(ObjectPathNode* parent, size_t index)
: ObjectPathNode(parent), index(index) {}

bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode& other) const {
const auto& otherMissingElement = static_cast<const MissingArrayElementPathNode&>(other);
return index == otherMissingElement.index;
bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const {
const auto* otherMissingElement = static_cast<const MissingArrayElementPathNode*>(other);
return index == otherMissingElement->index;
}

std::string MissingArrayElementPathNode::LastNodeString() const {
Expand All @@ -245,9 +245,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
MapValuePathNode::MapValuePathNode(ObjectPathNode* parent, ObjectRef key)
: ObjectPathNode(parent), key(std::move(key)) {}

bool MapValuePathNode::LastNodeEqual(const ObjectPathNode& other) const {
const auto& otherMapValue = static_cast<const MapValuePathNode&>(other);
return ObjectEqual()(key, otherMapValue.key);
bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const {
const auto* otherMapValue = static_cast<const MapValuePathNode*>(other);
return ObjectEqual()(key, otherMapValue->key);
}

std::string MapValuePathNode::LastNodeString() const {
Expand All @@ -262,7 +262,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<MapValuePathNode>(PrintO

MissingMapEntryPathNode::MissingMapEntryPathNode(ObjectPathNode* parent) : ObjectPathNode(parent) {}

bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode& other) const { return true; }
bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; }

std::string MissingMapEntryPathNode::LastNodeString() const { return "[<missing entry>]"; }

Expand Down
21 changes: 12 additions & 9 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs,
class RemapVarSEqualHandler : public SEqualReducer::Handler {
public:
explicit RemapVarSEqualHandler(bool assert_mode, FirstMismatch* first_mismatch)
: assert_mode_(assert_mode), first_mismatch(first_mismatch) {}
: assert_mode_(assert_mode), first_mismatch_(first_mismatch) {}

bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const ObjectPathPair& current_paths) final {
Expand Down Expand Up @@ -259,7 +259,7 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
equal_map_rhs_.clear();

ObjectPathPair current_paths;
if (first_mismatch != nullptr) {
if (IsPathTracingEnabled()) {
current_paths.lhs_path = current_paths.rhs_path = ObjectPath::Root();
}
if (!SEqualReduce(lhs, rhs, map_free_vars, current_paths)) {
Expand All @@ -277,8 +277,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
// Check the result.
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs,
const ObjectPathPair& current_paths) {
if (first_mismatch != nullptr && !result) {
first_mismatch->MaybeStoreMismatch(current_paths);
if (IsPathTracingEnabled() && !result) {
first_mismatch_->MaybeStoreMismatch(current_paths);
}
if (assert_mode_ && !result) {
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
Expand All @@ -299,8 +299,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
auto& entry = task_stack_.back();

if (entry.force_fail) {
if (first_mismatch != nullptr) {
first_mismatch->MaybeStoreMismatch(entry.current_paths);
if (IsPathTracingEnabled()) {
first_mismatch_->MaybeStoreMismatch(entry.current_paths);
}
return false;
}
Expand Down Expand Up @@ -356,11 +356,11 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
if (equal_map_rhs_.count(rhs)) return false;

// Run reduce check for free nodes.
if (first_mismatch == nullptr) {
if (!IsPathTracingEnabled()) {
return vtable_->SEqualReduce(lhs.get(), rhs.get(),
SEqualReducer(this, nullptr, map_free_vars));
} else {
PathTracingData tracing_data = {current_paths, lhs, rhs, first_mismatch};
PathTracingData tracing_data = {current_paths, lhs, rhs, first_mismatch_};
return vtable_->SEqualReduce(lhs.get(), rhs.get(),
SEqualReducer(this, &tracing_data, map_free_vars));
}
Expand Down Expand Up @@ -394,6 +394,9 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
Task(ForceFailTag, const ObjectPathPair& current_paths)
: current_paths(current_paths), force_fail(true) {}
};

bool IsPathTracingEnabled() const { return first_mismatch_ != nullptr; }

// list of pending tasks to be pushed to the stack.
std::vector<Task> pending_tasks_;
// Internal task stack to executed the task.
Expand All @@ -403,7 +406,7 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
// If in assert mode, must return true, and will throw error otherwise.
bool assert_mode_{false};
// Location to store the paths to the first detected mismatch, or nullptr to disable path tracing.
FirstMismatch* first_mismatch;
FirstMismatch* first_mismatch_;
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
// map from lhs to rhs
Expand Down

0 comments on commit 37183a2

Please sign in to comment.