diff --git a/include/tvm/script/printer/traced_object.h b/include/tvm/script/printer/traced_object.h new file mode 100644 index 000000000000..6f04b66cec97 --- /dev/null +++ b/include/tvm/script/printer/traced_object.h @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/script/printer/traced_object.h + * Wrappers around TVM objects that also store an ObjectPath from some "root" object + * to the wrapper object. + */ + +#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ +#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { + +template +class TracedObject; +template +class TracedMap; +template +class TracedArray; +template +class TracedOptional; +template +class TracedBasicValue; + +namespace detail { + +template ::value> +struct TracedObjectWrapperSelector; + +template +struct TracedObjectWrapperSelector { + using Type = TracedBasicValue; +}; + +template +struct TracedObjectWrapperSelector { + using Type = TracedObject; +}; + +template +struct TracedObjectWrapperSelector, true> { + using Type = TracedMap; +}; + +template +struct TracedObjectWrapperSelector, true> { + using Type = TracedArray; +}; + +template +struct TracedObjectWrapperSelector, true> { + using Type = TracedOptional; +}; + +} // namespace detail + +/*! + * \brief Traced wrapper for regular (non-container) TVM objects. + */ +template +class TracedObject { + using ObjectType = typename RefT::ContainerType; + + public: + // Don't use this direcly. For convenience, call MakeTraced() instead. + explicit TracedObject(const RefT& object_ref, ObjectPath path) + : ref_(object_ref), path_(std::move(path)) {} + + // Implicit conversion from a derived reference class + template + TracedObject(const TracedObject& derived) + : ref_(derived.Get()), path_(derived.GetPath()) {} + + /*! + * \brief Get a traced wrapper for an attribute of the wrapped object. + */ + template + typename detail::TracedObjectWrapperSelector::Type GetAttr(T BaseType::*member_ptr) const { + using WrapperType = typename detail::TracedObjectWrapperSelector::Type; + const ObjectType* node = static_cast(ref_.get()); + const T& attr = node->*member_ptr; + Optional attr_key = ICHECK_NOTNULL(GetAttrKeyByAddress(node, &attr)); + return WrapperType(attr, path_->Attr(attr_key)); + } + + /*! + * \brief Access the wrapped object. + */ + const RefT& Get() const { return ref_; } + + /*! + * \brief Check if the reference to the wrapped object can be converted to `RefU`. + */ + template + bool IsInstance() const { + return ref_->template IsInstance(); + } + + /*! + * \brief Same as Get().defined(). + */ + bool defined() const { return ref_.defined(); } + + /*! + * \brief Convert the wrapped reference type to a subtype. + * + * Throws an exception if IsInstance() is false. + */ + template + TracedObject Downcast() const { + return TracedObject(tvm::runtime::Downcast(ref_), path_); + } + + /*! + * \brief Convert the wrapped reference type to a subtype. + * + * Returns an empty optional if IsInstance() is false. + */ + template + TracedOptional TryDowncast() const { + if (ref_->template IsInstance()) { + return Downcast(); + } else { + return TracedOptional(NullOpt, path_); + } + } + + /*! + * \brief Get the path of the wrapped object. + */ + const ObjectPath& GetPath() const { return path_; } + + private: + RefT ref_; + ObjectPath path_; +}; + +/*! + * \brief Iterator class for TracedMap + */ +template +class TracedMapIterator { + public: + using WrappedV = typename detail::TracedObjectWrapperSelector::Type; + using MapIter = typename Map::iterator; + + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = ptrdiff_t; + using value_type = const std::pair; + using pointer = value_type*; + using reference = value_type; + + explicit TracedMapIterator(MapIter iter, ObjectPath map_path) + : iter_(iter), map_path_(std::move(map_path)) {} + + bool operator==(const TracedMapIterator& other) const { return iter_ == other.iter_; } + + bool operator!=(const TracedMapIterator& other) const { return iter_ != other.iter_; } + + pointer operator->() const = delete; + + reference operator*() const { + auto kv = *iter_; + return std::make_pair(kv.first, WrappedV(kv.second, map_path_->MapValue(kv.first))); + } + + TracedMapIterator& operator++() { + ++iter_; + return *this; + } + + TracedMapIterator operator++(int) { + TracedMapIterator copy = *this; + ++(*this); + return copy; + } + + private: + MapIter iter_; + ObjectPath map_path_; +}; + +/*! + * \brief Traced wrapper for Map objects. + */ +template +class TracedMap { + public: + using WrappedV = typename detail::TracedObjectWrapperSelector::Type; + + using iterator = TracedMapIterator; + + // Don't use this direcly. For convenience, call MakeTraced() instead. + explicit TracedMap(Map map, ObjectPath path) + : map_(std::move(map)), path_(std::move(path)) {} + + /*! + * \brief Get a value by its key, wrapped in a traced wrapper. + */ + WrappedV at(const K& key) const { + auto it = map_.find(key); + ICHECK(it != map_.end()) << "No such key in Map"; + auto kv = *it; + return WrappedV(kv.second, path_->MapValue(kv.first)); + } + + /*! + * \brief Access the wrapped map object. + */ + const Map& Get() const { return map_; } + + /*! + * \brief Get the path of the wrapped object. + */ + const ObjectPath& GetPath() const { return path_; } + + /*! + * \brief Get an iterator to the first item of the map. + */ + iterator begin() const { return iterator(map_.begin(), path_); } + + /*! + * \brief Get an iterator to the end of the map. + */ + iterator end() const { return iterator(map_.end(), path_); } + + /*! + * \brief Returns true iff the wrapped map is empty. + */ + bool empty() const { return map_.empty(); } + + private: + Map map_; + ObjectPath path_; +}; + +/*! + * \brief Iterator class for TracedArray + */ +template +class TracedArrayIterator { + public: + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + + using difference_type = ptrdiff_t; + using value_type = WrappedT; + using pointer = WrappedT*; + using reference = WrappedT&; + using iterator_category = std::random_access_iterator_tag; + + explicit TracedArrayIterator(Array array, size_t index, ObjectPath array_path) + : array_(array), index_(index), array_path_(array_path) {} + + TracedArrayIterator& operator++() { + ++index_; + return *this; + } + TracedArrayIterator& operator--() { + --index_; + return *this; + } + TracedArrayIterator operator++(int) { + TracedArrayIterator copy = *this; + ++index_; + return copy; + } + TracedArrayIterator operator--(int) { + TracedArrayIterator copy = *this; + --index_; + return copy; + } + + TracedArrayIterator operator+(difference_type offset) const { + return TracedArrayIterator(array_, index_ + offset, array_path_); + } + + TracedArrayIterator operator-(difference_type offset) const { + return TracedArrayIterator(array_, index_ - offset, array_path_); + } + + difference_type operator-(const TracedArrayIterator& rhs) const { return index_ - rhs.index_; } + + bool operator==(TracedArrayIterator other) const { + return array_.get() == other.array_.get() && index_ == other.index_; + } + bool operator!=(TracedArrayIterator other) const { return !(*this == other); } + value_type operator*() const { return WrappedT(array_[index_], array_path_->ArrayIndex(index_)); } + + private: + Array array_; + size_t index_; + ObjectPath array_path_; +}; + +/*! + * \brief Traced wrapper for Array objects. + */ +template +class TracedArray { + public: + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + + using iterator = TracedArrayIterator; + + // Don't use this direcly. For convenience, call MakeTraced() instead. + explicit TracedArray(Array array, ObjectPath path) + : array_(std::move(array)), path_(std::move(path)) {} + + /*! + * \brief Access the wrapped array object. + */ + const Array& Get() const { return array_; } + + /*! + * \brief Get the path of the wrapped array object. + */ + const ObjectPath& GetPath() const { return path_; } + + /*! + * \brief Get an element by index, wrapped in a traced wrapper. + */ + WrappedT operator[](size_t index) const { + return WrappedT(array_[index], path_->ArrayIndex(index)); + } + + /*! + * \brief Get an iterator to the first array element. + * + * The iterator's dereference operator will automatically wrap each element in a traced wrapper. + */ + iterator begin() const { return iterator(array_, 0, path_); } + + /*! + * \brief Get an iterator to the end of the array. + * + * The iterator's dereference operator will automatically wrap each element in a traced wrapper. + */ + iterator end() const { return iterator(array_, array_.size(), path_); } + + /*! + * \brief Returns true iff the wrapped array is empty. + */ + bool empty() const { return array_.empty(); } + + /*! + * \brief Get the size of the wrapped array. + */ + size_t size() const { return array_.size(); } + + private: + Array array_; + ObjectPath path_; +}; + +/*! + * \brief Traced wrapper for Optional objects. + */ +template +class TracedOptional { + public: + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + + /*! + * \brief Implicit conversion from the corresponding non-optional traced wrapper. + */ + TracedOptional(const WrappedT& value) // NOLINT(runtime/explicit) + : optional_(value.Get().defined() ? value.Get() : Optional(NullOpt)), + path_(value.GetPath()) {} + + // Don't use this direcly. For convenience, call MakeTraced() instead. + explicit TracedOptional(Optional optional, ObjectPath path) + : optional_(std::move(optional)), path_(std::move(path)) {} + + /*! + * \brief Access the wrapped optional object. + */ + const Optional& Get() const { return optional_; } + + /*! + * \brief Get the path of the wrapped optional object. + */ + const ObjectPath& GetPath() const { return path_; } + + /*! + * \brief Returns true iff the object is present. + */ + bool defined() const { return optional_.defined(); } + + /*! + * \brief Returns a non-optional traced wrapper, throws if defined() is false. + */ + WrappedT value() const { return WrappedT(optional_.value(), path_); } + + /*! + * \brief Same as defined(). + */ + explicit operator bool() const { return optional_.defined(); } + + private: + Optional optional_; + ObjectPath path_; +}; + +/*! + * \brief Traced wrapper for basic values (i.e. non-TVM objects) + */ +template +class TracedBasicValue { + public: + explicit TracedBasicValue(const T& value, ObjectPath path) + : value_(value), path_(std::move(path)) {} + + /*! + * \brief Access the wrapped value. + */ + const T& Get() const { return value_; } + + /*! + * \brief Get the path of the wrapped value. + */ + const ObjectPath& GetPath() const { return path_; } + + /*! + * \brief Transform the wrapped value without changing its path. + */ + template + typename detail::TracedObjectWrapperSelector::type>::Type + ApplyFunc(F&& f) const { + return MakeTraced(f(value_), path_); + } + + private: + T value_; + ObjectPath path_; +}; + +/*! + * \brief Wrap the given root object in an appropriate traced wrapper class. + */ +template +typename detail::TracedObjectWrapperSelector::Type MakeTraced(const RefT& object) { + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + return WrappedT(object, ObjectPath::Root()); +} + +/*! + * \brief Wrap the given object with the given path in an appropriate traced wrapper class. + */ +template +typename detail::TracedObjectWrapperSelector::Type MakeTraced(const RefT& object, + ObjectPath path) { + using WrappedT = typename detail::TracedObjectWrapperSelector::Type; + return WrappedT(object, std::move(path)); +} + +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_ diff --git a/tests/cpp/traced_object_test.cc b/tests/cpp/traced_object_test.cc new file mode 100644 index 000000000000..7890a67eef95 --- /dev/null +++ b/tests/cpp/traced_object_test.cc @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +using namespace tvm; + +namespace { + +class DummyObjectNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "TracedObjectTestDummyObject"; + TVM_DECLARE_FINAL_OBJECT_INFO(DummyObjectNode, Object); +}; + +class DummyObject : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DummyObject, ObjectRef, DummyObjectNode); +}; + +TVM_REGISTER_NODE_TYPE(DummyObjectNode); + +class ObjectWithAttrsNode : public Object { + public: + int64_t int64_attr = 5; + Map map_attr; + Array array_attr; + DummyObject obj_attr; + + ObjectWithAttrsNode() : obj_attr(make_object()) {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("int64_attr", &int64_attr); + v->Visit("map_attr", &map_attr); + v->Visit("array_attr", &array_attr); + v->Visit("obj_attr", &obj_attr); + } + + static constexpr const char* _type_key = "TracedObjectTestObjectWithAttrs"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectWithAttrsNode, Object); +}; + +class ObjectWithAttrs : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ObjectWithAttrs, ObjectRef, ObjectWithAttrsNode); +}; + +TVM_REGISTER_NODE_TYPE(ObjectWithAttrsNode); + +} // anonymous namespace + +TEST(TracedObjectTest, MakeTraced_RootObject) { + ObjectWithAttrs root(make_object()); + auto root_traced = MakeTraced(root); + + static_assert(std::is_same>::value); + ICHECK(root_traced.GetPath()->PathsEqual(ObjectPath::Root())); + ICHECK_EQ(root_traced.Get().get(), root.get()); +} + +TEST(TracedObjectTest, MakeTraced_WithPath) { + ObjectWithAttrs obj(make_object()); + auto traced = MakeTraced(obj, ObjectPath::Root()->Attr("foo")); + + static_assert(std::is_same>::value); + ICHECK(traced.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo"))); + ICHECK_EQ(traced.Get().get(), obj.get()); +} + +TEST(TracedObjectTest, TracedObject_ImplicitConversionFromDerived) { + DummyObject obj(make_object()); + auto traced = MakeTraced(obj); + static_assert(std::is_same>::value); + + // Check that TracedObject is implicitly converted to TracedObject + auto base_traced = [](const TracedObject& base) { return base; }(traced); + + static_assert(std::is_same>::value); +} + +TEST(TracedObjectTest, TracedObject_GetAttr_ObjectRef) { + ObjectWithAttrs root(make_object()); + auto root_traced = MakeTraced(root); + auto obj_attr = root_traced.GetAttr(&ObjectWithAttrsNode::obj_attr); + static_assert(std::is_same>::value); + ICHECK(obj_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("obj_attr"))); + ICHECK_EQ(obj_attr.Get().get(), root->obj_attr.get()); +} + +TEST(TracedObjectTest, TracedObject_GetAttr_Map) { + ObjectWithAttrs root(make_object()); + root->map_attr.Set("foo", "bar"); + + auto root_traced = MakeTraced(root); + auto map_attr = root_traced.GetAttr(&ObjectWithAttrsNode::map_attr); + static_assert(std::is_same>::value); + ICHECK(map_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("map_attr"))); + ICHECK_EQ(map_attr.Get().get(), root->map_attr.get()); + + auto map_val = map_attr.at("foo"); + ICHECK_EQ(map_val.Get(), "bar"); + ICHECK( + map_val.GetPath()->PathsEqual(ObjectPath::Root()->Attr("map_attr")->MapValue(String("foo")))); +} + +TEST(TracedObjectTest, TracedObject_GetAttr_Array) { + ObjectWithAttrs root(make_object()); + root->array_attr.push_back("foo"); + root->array_attr.push_back("bar"); + + auto root_traced = MakeTraced(root); + auto array_attr = root_traced.GetAttr(&ObjectWithAttrsNode::array_attr); + static_assert(std::is_same>::value); + ICHECK(array_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("array_attr"))); + ICHECK_EQ(array_attr.Get().get(), root->array_attr.get()); + + auto array_val = array_attr[1]; + ICHECK_EQ(array_val.Get(), "bar"); + ICHECK(array_val.GetPath()->PathsEqual(ObjectPath::Root()->Attr("array_attr")->ArrayIndex(1))); +} + +TEST(TracedObjectTest, TracedObject_GetAttr_Int64) { + ObjectWithAttrs root(make_object()); + auto root_traced = MakeTraced(root); + + auto int64_attr = root_traced.GetAttr(&ObjectWithAttrsNode::int64_attr); + static_assert(std::is_same>::value); + ICHECK_EQ(int64_attr.Get(), 5); + ICHECK(int64_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("int64_attr"))); +} + +TEST(TracedObjectTest, TracedObject_IsInstance) { + ObjectRef dummy(make_object()); + auto traced = MakeTraced(dummy); + ICHECK(traced.IsInstance()); + ICHECK(!traced.IsInstance()); +} + +TEST(TracedObjectTest, TracedObject_Downcast) { + ObjectRef root(make_object()); + auto traced = MakeTraced(root); + + auto as_dummy = traced.Downcast(); + static_assert(std::is_same>::value); + ICHECK_EQ(as_dummy.Get(), root); + + // Try downcasting to a wrong type + bool caught = false; + try { + traced.Downcast(); + } catch (std::exception& e) { + caught = strstr(e.what(), + "Downcast from TracedObjectTestDummyObject to TracedObjectTestObjectWithAttrs " + "failed") != nullptr; + } + ICHECK(caught); +} + +TEST(TracedObjectTest, TracedObject_TryDowncast) { + ObjectRef root(make_object()); + auto traced = MakeTraced(root); + + auto as_dummy = traced.TryDowncast(); + static_assert(std::is_same>::value); + ICHECK(as_dummy.defined()); + ICHECK_EQ(as_dummy.value().Get(), root); + + // Try downcasting to a wrong type + ICHECK(!traced.TryDowncast().defined()); +} + +TEST(TracedObjectTest, TracedMap_At) { + Map m({{"k1", "foo"}, {"k2", "bar"}}); + auto traced = MakeTraced(m); + + auto traced_foo = traced.at("k1"); + static_assert(std::is_same>::value); + ICHECK_EQ(traced_foo.Get(), "foo"); + ICHECK(traced_foo.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1")))); +} + +TEST(TracedObjectTest, TracedMap_Iterator) { + Map m({{"k1", "foo"}, {"k2", "bar"}}); + auto traced = MakeTraced(m); + + size_t k1_count = 0; + size_t k2_count = 0; + + for (const auto& kv : traced) { + if (kv.first == "k1") { + ++k1_count; + ICHECK_EQ(kv.second.Get(), "foo"); + ICHECK(kv.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1")))); + } else if (kv.first == "k2") { + ++k2_count; + ICHECK_EQ(kv.second.Get(), "bar"); + ICHECK(kv.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k2")))); + } else { + ICHECK(false); + } + } + + ICHECK_EQ(k1_count, 1); + ICHECK_EQ(k2_count, 1); +} + +TEST(TracedObjectTest, TracedArray_Index) { + Array a = {"foo", "bar"}; + auto traced = MakeTraced(a); + + auto traced_bar = traced[1]; + static_assert(std::is_same>::value); + ICHECK_EQ(traced_bar.Get(), "bar"); + ICHECK(traced_bar.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1))); +} + +TEST(TracedObjectTest, TracedArray_Iterator) { + Array a = {"foo", "bar"}; + auto traced = MakeTraced(a); + + size_t index = 0; + for (const auto& x : traced) { + if (index == 0) { + ICHECK_EQ(x.Get(), "foo"); + ICHECK(x.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(0))); + } else if (index == 1) { + ICHECK_EQ(x.Get(), "bar"); + ICHECK(x.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1))); + } else { + ICHECK(false); + } + ++index; + } + + ICHECK_EQ(index, 2); +} + +TEST(TracedObjectTest, TracedBasicValue_ApplyFunc) { + auto traced = MakeTraced(123, ObjectPath::Root()->Attr("foo")); + static_assert(std::is_same>::value); + + auto transformed = traced.ApplyFunc([](int x) { return x + 4.0; }); + static_assert(std::is_same>::value); + + ICHECK(transformed.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo"))); +}