Skip to content

Commit

Permalink
[TVMScript] Add object path tracing to StructuralEqual
Browse files Browse the repository at this point in the history
Motivation: when two IR objects fail a structural equality check, currently there is no easy way to
find out which part of the IR caused the mismatch. In this PR, we modify the `StructuralEqual`
infrastructure to also optionally return a pair of `ObjectPath` objects that point to the mismatch.
(See apache#11977). In the upcoming PRs, we will pass these paths to the
TIR printer, so that it could highlight the mismatch location nicely.

Tracking issue: apache#11912
  • Loading branch information
gbonik authored and junrushao committed Jul 27, 2022
1 parent 98d5feb commit 4c28f15
Show file tree
Hide file tree
Showing 10 changed files with 932 additions and 42 deletions.
6 changes: 6 additions & 0 deletions include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,5 +404,11 @@ inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr
}
}

/*!
* \brief Given an object and an address of its attribute, return the key of the attribute.
* \return nullptr if no attribute with the given address exists.
*/
Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address);

} // namespace tvm
#endif // TVM_NODE_REFLECTION_H_
157 changes: 140 additions & 17 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#define TVM_NODE_STRUCTURAL_EQUAL_H_

#include <tvm/node/functor.h>
#include <tvm/node/object_path.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/data_type.h>

Expand Down Expand Up @@ -56,6 +57,31 @@ class BaseValueEqual {
}
};

/*!
* \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
*/
struct ObjectPathPair {
ObjectPath lhs_path;
ObjectPath rhs_path;
};

// Could be replaced with std::optional<ObjectPathPair>
class OptionalObjectPathPair {
public:
OptionalObjectPathPair() = default;

OptionalObjectPathPair(const ObjectPathPair& p) // NOLINT(runtime/explicit)
: lhs_path(p.lhs_path), rhs_path(p.rhs_path) {}

bool defined() const { return lhs_path.defined(); }

ObjectPathPair value() const { return {lhs_path.value(), rhs_path.value()}; }

private:
Optional<ObjectPath> lhs_path;
Optional<ObjectPath> rhs_path;
};

/*!
* \brief Content-aware structural equality comparator for objects.
*
Expand Down Expand Up @@ -99,7 +125,10 @@ class StructuralEqual : public BaseValueEqual {
* equality checking. Instead, it can store the necessary equality conditions
* and check later via an internally managed stack.
*/
class SEqualReducer : public BaseValueEqual {
class SEqualReducer {
private:
struct PathTracingData;

public:
/*! \brief Internal handler that defines custom behaviors.. */
class Handler {
Expand All @@ -110,12 +139,24 @@ class SEqualReducer : public BaseValueEqual {
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether do we allow remap variables if possible.
* \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
*
* \return false if there is an immediate failure, true otherwise.
* \note This function may save the equality condition of (lhs == rhs) in an internal
* stack and try to resolve later.
*/
virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0;
virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const OptionalObjectPathPair& current_paths) = 0;

/*!
* \brief Mark the comparison as failed, but don't fail immediately.
*
* This is useful for producing better error messages when comparing containers.
* For example, if two array sizes mismatch, it's better to mark the comparison as failed
* but compare array elements anyway, so that we could find the true first mismatch.
*/
virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;

/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
Expand All @@ -129,28 +170,72 @@ class SEqualReducer : public BaseValueEqual {
* \brief Mark current comparison as graph node equal comparison.
*/
virtual void MarkGraphNode() = 0;
};

using BaseValueEqual::operator();
protected:
using PathTracingData = SEqualReducer::PathTracingData;
};

/*! \brief default constructor */
SEqualReducer() = default;
/*!
* \brief Constructor with a specific handler.
* \param handler The equal handler for objects.
* \param tracing_data Optional pointer to the path tracing data.
* \param map_free_vars Whether or not to map free variables.
*/
explicit SEqualReducer(Handler* handler, bool map_free_vars)
: handler_(handler), map_free_vars_(map_free_vars) {}
explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
: handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}

/*!
* \brief Reduce condition to comparison of two attribute values.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const double& lhs, const double& rhs) const;
bool operator()(const int64_t& lhs, const int64_t& rhs) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
bool operator()(const int& lhs, const int& rhs) const;
bool operator()(const bool& lhs, const bool& rhs) const;
bool operator()(const std::string& lhs, const std::string& rhs) const;
bool operator()(const DataType& lhs, const DataType& rhs) const;

template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
using Underlying = typename std::underlying_type<ENum>::type;
static_assert(std::is_same<Underlying, int>::value,
"Enum must have `int` as the underlying type");
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
}

/*!
* \brief Reduce condition to comparison of two objects.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;

/*!
* \brief Reduce condition to comparison of two objects.
*
* Like `operator()`, but with an additional `paths` parameter that specifies explicit object
* paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
* objects like Array and Map, or other custom objects that store nested objects that are not
* simply attributes.
*
* Can only be called when `IsPathTracingEnabled()` is `true`.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param paths Object paths for `lhs` and `rhs`.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
}

/*!
* \brief Reduce condition to comparison of two definitions,
* where free vars can be mapped.
Expand All @@ -162,9 +247,8 @@ class SEqualReducer : public BaseValueEqual {
* \param rhs The right operand.
* \return the immediate check result.
*/
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
return handler_->SEqualReduce(lhs, rhs, true);
}
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);

/*!
* \brief Reduce condition to comparison of two arrays.
* \param lhs The left operand.
Expand All @@ -173,13 +257,20 @@ class SEqualReducer : public BaseValueEqual {
*/
template <typename T>
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(operator()(lhs[i], rhs[i]))) return false;
if (tracing_data_ == nullptr) {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(operator()(lhs[i], rhs[i]))) return false;
}
return true;
}
return true;

// If tracing is enabled, fall back to the regular path
const ObjectRef& lhs_obj = lhs;
const ObjectRef& rhs_obj = rhs;
return (*this)(lhs_obj, rhs_obj);
}
/*!
* \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
Expand All @@ -198,9 +289,41 @@ class SEqualReducer : public BaseValueEqual {
/*! \return Get the internal handler. */
Handler* operator->() const { return handler_; }

/*! \brief Check if this reducer is tracing paths to the first mismatch. */
bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }

/*!
* \brief Get the paths of the currently compared objects.
*
* Can only be called when `IsPathTracingEnabled()` is true.
*/
const ObjectPathPair& GetCurrentObjectPaths() const;

/*!
* \brief Specify the object paths of a detected mismatch.
*
* Can only be called when `IsPathTracingEnabled()` is true.
*/
void RecordMismatchPaths(const ObjectPathPair& paths) const;

private:
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;

bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const ObjectPathPair* paths) const;

static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
const void* rhs_address,
const PathTracingData* tracing_data);

template <typename T>
static bool CompareAttributeValues(const T& lhs, const T& rhs,
const PathTracingData* tracing_data);

/*! \brief Internal class pointer. */
Handler* handler_;
/*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
const PathTracingData* tracing_data_;
/*! \brief Whether or not to map free vars. */
bool map_free_vars_;
};
Expand Down
35 changes: 33 additions & 2 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def structural_equal(lhs, rhs, map_free_vars=False):
The left operand.
map_free_vars : bool
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.
Whether free variables (i.e. variables without a definition site) should be mapped
as equal to each other.
Return
------
Expand All @@ -209,6 +209,37 @@ def structural_equal(lhs, rhs, map_free_vars=False):
return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))


def get_first_structural_mismatch(lhs, rhs, map_free_vars=False):
"""Like structural_equal(), but returns the ObjectPaths of the first detected mismatch.
Parameters
----------
lhs : Object
The left operand.
rhs : Object
The left operand.
map_free_vars : bool
Whether free variables (i.e. variables without a definition site) should be mapped
as equal to each other.
Returns
-------
mismatch: Optional[Tuple[ObjectPath, ObjectPath]]
`None` if `lhs` and `rhs` are structurally equal.
Otherwise, a tuple of two ObjectPath objects that point to the first detected mismtach.
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
mismatch = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars)
if len(mismatch) == 0:
return None
else:
assert len(mismatch) == 2
return tuple(mismatch)


def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""Assert lhs and rhs are structurally equal to each other.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# class exposures
from .packed_func import PackedFunc
from .object import Object
from .object_path import ObjectPath
from .object_generic import ObjectGeneric, ObjectTypes
from .ndarray import NDArray, DataType, DataTypeCode, Device
from .module import Module, num_threads
Expand Down
44 changes: 44 additions & 0 deletions src/node/reflection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,48 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);
TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);

TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);

namespace {
// Attribute visitor class for finding the attribute key by its address
class GetAttrKeyByAddressVisitor : public AttrVisitor {
public:
explicit GetAttrKeyByAddressVisitor(const void* attr_address)
: attr_address_(attr_address), key_(nullptr) {}

void Visit(const char* key, double* value) final { DoVisit(key, value); }
void Visit(const char* key, int64_t* value) final { DoVisit(key, value); }
void Visit(const char* key, uint64_t* value) final { DoVisit(key, value); }
void Visit(const char* key, int* value) final { DoVisit(key, value); }
void Visit(const char* key, bool* value) final { DoVisit(key, value); }
void Visit(const char* key, std::string* value) final { DoVisit(key, value); }
void Visit(const char* key, void** value) final { DoVisit(key, value); }
void Visit(const char* key, DataType* value) final { DoVisit(key, value); }
void Visit(const char* key, runtime::NDArray* value) final { DoVisit(key, value); }
void Visit(const char* key, runtime::ObjectRef* value) final { DoVisit(key, value); }

const char* GetKey() const { return key_; }

private:
const void* attr_address_;
const char* key_;

void DoVisit(const char* key, const void* candidate) {
if (attr_address_ == candidate) {
key_ = key;
}
}
};
} // anonymous namespace

Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address) {
GetAttrKeyByAddressVisitor visitor(attr_address);
ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &visitor);
const char* key = visitor.GetKey();
if (key == nullptr) {
return NullOpt;
} else {
return String(key);
}
}

} // namespace tvm
Loading

0 comments on commit 4c28f15

Please sign in to comment.