From a9c610f1c0545753b699c7dbd2f5f2ca2568a16c Mon Sep 17 00:00:00 2001 From: Greg Bonik Date: Wed, 13 Jul 2022 16:03:11 -0700 Subject: [PATCH] [TVMScript] Add ObjectPath class (#11977) Motivation: Same IR node object can be referenced in several different contexts inside a larger IR object. For example, a variable could be referenced in several statements within a block. This makes it impossible to use an object pointer to uniquely identify a "location" within the larger IR object for error reporting purposes. The `ObjectPath` class addresses this problem by serving as a unique "locator". Tracking issue: https://github.com/apache/tvm/issues/11912 --- include/tvm/node/object_path.h | 282 ++++++++++++++++++++ python/tvm/runtime/object_path.py | 124 +++++++++ src/node/object_path.cc | 310 ++++++++++++++++++++++ tests/python/unittest/test_object_path.py | 149 +++++++++++ 4 files changed, 865 insertions(+) create mode 100644 include/tvm/node/object_path.h create mode 100644 python/tvm/runtime/object_path.py create mode 100644 src/node/object_path.cc create mode 100644 tests/python/unittest/test_object_path.py diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h new file mode 100644 index 000000000000..5175c5b0c40d --- /dev/null +++ b/include/tvm/node/object_path.h @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/node/object_path.h + * ObjectPath class that represents a path from a root object to one of its descendants + * via attribute access, array indexing etc. + */ + +#ifndef TVM_NODE_OBJECT_PATH_H_ +#define TVM_NODE_OBJECT_PATH_H_ + +#include +#include +#include + +#include + +namespace tvm { + +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; + +class ObjectPath; + +/*! + * \brief Path to an object from some root object. + * + * Motivation: + * + * Same IR node object can be referenced in several different contexts inside a larger IR object. + * For example, a variable could be referenced in several statements within a block. + * + * This makes it impossible to use an object pointer to uniquely identify a "location" within + * the larger IR object for error reporting purposes. The ObjectPath class addresses this problem + * by serving as a unique "locator". + */ +class ObjectPathNode : public Object { + public: + /*! \brief Get the parent path */ + Optional GetParent() const; + /*! + * \brief Get the length of the path. + * + * For example, the path returned by `ObjectPath::Root()` has length 1. + */ + int32_t Length() const; + + /*! + * \brief Get a path prefix of the given length. + * + * Provided `length` must not exceed the `Length()` of this path. + */ + ObjectPath GetPrefix(int32_t length) const; + + /*! + * \brief Check if this path is a prefix of another path. + * + * The prefix is not strict, i.e. a path is considered a prefix of itself. + */ + bool IsPrefixOf(const ObjectPath& other) const; + + /*! \brief Check if two paths are equal. */ + bool PathsEqual(const ObjectPath& other) const; + + /*! \brief Extend this path with access to an object attribute. */ + ObjectPath Attr(const char* attr_key) const; + + /*! \brief Extend this path with access to an object attribute. */ + ObjectPath Attr(Optional attr_key) const; + + /*! \brief Extend this path with access to an array element. */ + ObjectPath ArrayIndex(int32_t index) const; + + /*! \brief Extend this path with access to a missing array element. */ + ObjectPath MissingArrayElement(int32_t index) const; + + /*! \brief Extend this path with access to a map value. */ + ObjectPath MapValue(ObjectRef key) const; + + /*! \brief Extend this path with access to a missing map entry. */ + ObjectPath MissingMapEntry() const; + + static constexpr const char* _type_key = "ObjectPath"; + TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object); + + protected: + explicit ObjectPathNode(const ObjectPathNode* parent); + + friend class ObjectPath; + friend std::string GetObjectPathRepr(const ObjectPathNode* node); + + const ObjectPathNode* ParentNode() const; + + /*! Compares just the last node of the path, without comparing the whole path. */ + virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0; + + virtual std::string LastNodeString() const = 0; + + private: + Optional parent_; + int32_t length_; +}; + +class ObjectPath : public ObjectRef { + public: + /*! \brief Create a path that represents the root object itself. */ + static ObjectPath Root(); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); +}; + +//------------------------------------------------------------------------- +//----- Concrete object path nodes ------------------------------------ +//------------------------------------------------------------------------- + +// ----- Root ----- + +class RootPathNode final : public ObjectPathNode { + public: + explicit RootPathNode(); + + static constexpr const char* _type_key = "RootPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class RootPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RootPath, ObjectPath, RootPathNode); +}; + +// ----- Attribute access ----- + +class AttributeAccessPathNode final : public ObjectPathNode { + public: + /*! \brief Name of the attribute being accessed. Must be a static string. */ + String attr_key; + + explicit AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key); + + static constexpr const char* _type_key = "AttributeAccessPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class AttributeAccessPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AttributeAccessPath, ObjectPath, AttributeAccessPathNode); +}; + +// ----- Unknown attribute access ----- + +class UnknownAttributeAccessPathNode final : public ObjectPathNode { + public: + explicit UnknownAttributeAccessPathNode(const ObjectPathNode* parent); + + static constexpr const char* _type_key = "UnknownAttributeAccessPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class UnknownAttributeAccessPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(UnknownAttributeAccessPath, ObjectPath, + UnknownAttributeAccessPathNode); +}; + +// ----- Array element access by index ----- + +class ArrayIndexPathNode : public ObjectPathNode { + public: + /*! \brief Index of the array element that is being accessed. */ + int32_t index; + + explicit ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index); + + static constexpr const char* _type_key = "ArrayIndexPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class ArrayIndexPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ArrayIndexPath, ObjectPath, ArrayIndexPathNode); +}; + +// ----- Missing array element ----- + +class MissingArrayElementPathNode : public ObjectPathNode { + public: + /*! \brief Index of the array element that is missing. */ + int32_t index; + + explicit MissingArrayElementPathNode(const ObjectPathNode* parent, int32_t index); + + static constexpr const char* _type_key = "MissingArrayElementPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MissingArrayElementPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MissingArrayElementPath, ObjectPath, MissingArrayElementPathNode); +}; + +// ----- Map value ----- + +class MapValuePathNode : public ObjectPathNode { + public: + /*! \brief Key of the map entry that is being accessed */ + ObjectRef key; + + explicit MapValuePathNode(const ObjectPathNode* parent, ObjectRef key); + + static constexpr const char* _type_key = "MapValuePath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MapValuePath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MapValuePath, ObjectPath, MapValuePathNode); +}; + +// ----- Missing map entry ----- + +class MissingMapEntryPathNode : public ObjectPathNode { + public: + explicit MissingMapEntryPathNode(const ObjectPathNode* parent); + + static constexpr const char* _type_key = "MissingMapEntryPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MissingMapEntryPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MissingMapEntryPath, ObjectPath, MissingMapEntryPathNode); +}; + +} // namespace tvm + +#endif // TVM_NODE_OBJECT_PATH_H_ diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py new file mode 100644 index 000000000000..3eabce1f8694 --- /dev/null +++ b/python/tvm/runtime/object_path.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +ObjectPath class that represents a path from a root object to one of its descendants +via attribute access, array indexing etc. +""" + +import tvm._ffi +from tvm.runtime import Object +from . import _ffi_node_api + + +__all__ = ( + "ObjectPath", + "RootPath", + "AttributeAccessPath", + "UnknownAttributeAccessPath", + "ArrayIndexPath", + "MissingArrayElementPath", + "MapValuePath", + "MissingMapEntryPath", +) + + +@tvm._ffi.register_object("ObjectPath") +class ObjectPath(Object): + """ + Path to an object from some root object. + """ + + def __init__(self) -> None: + super().__init__() + raise ValueError( + "ObjectPath can't be initialized directly. " + "Use ObjectPath.root() to create a path to the root object" + ) + + @staticmethod + def root() -> "ObjectPath": + return _ffi_node_api.ObjectPathRoot() + + def __eq__(self, other): + return _ffi_node_api.ObjectPathEqual(self, other) + + def __ne__(self, other): + return not _ffi_node_api.ObjectPathEqual(self, other) + + @property + def parent(self) -> "ObjectPath": + return _ffi_node_api.ObjectPathGetParent(self) + + def __len__(self) -> int: + return _ffi_node_api.ObjectPathLength(self) + + def get_prefix(self, length) -> "ObjectPath": + return _ffi_node_api.ObjectPathGetPrefix(self, length) + + def is_prefix_of(self, other) -> "ObjectPath": + return _ffi_node_api.ObjectPathIsPrefixOf(self, other) + + def attr(self, attr_key) -> "ObjectPath": + return _ffi_node_api.ObjectPathAttr(self, attr_key) + + def array_index(self, index) -> "ObjectPath": + return _ffi_node_api.ObjectPathArrayIndex(self, index) + + def missing_array_element(self, index) -> "ObjectPath": + return _ffi_node_api.ObjectPathMissingArrayElement(self, index) + + def map_value(self, key) -> "ObjectPath": + return _ffi_node_api.ObjectPathMapValue(self, tvm.runtime.convert(key)) + + def missing_map_entry(self) -> "ObjectPath": + return _ffi_node_api.ObjectPathMissingMapEntry(self) + + +@tvm._ffi.register_object("RootPath") +class RootPath(ObjectPath): + pass + + +@tvm._ffi.register_object("AttributeAccessPath") +class AttributeAccessPath(ObjectPath): + pass + + +@tvm._ffi.register_object("UnknownAttributeAccessPath") +class UnknownAttributeAccessPath(ObjectPath): + pass + + +@tvm._ffi.register_object("ArrayIndexPath") +class ArrayIndexPath(ObjectPath): + pass + + +@tvm._ffi.register_object("MissingArrayElementPath") +class MissingArrayElementPath(ObjectPath): + pass + + +@tvm._ffi.register_object("MapValuePath") +class MapValuePath(ObjectPath): + pass + + +@tvm._ffi.register_object("MissingMapEntryPath") +class MissingMapEntryPath(ObjectPath): + pass diff --git a/src/node/object_path.cc b/src/node/object_path.cc new file mode 100644 index 000000000000..9c49daa8c376 --- /dev/null +++ b/src/node/object_path.cc @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include + +using namespace tvm::runtime; + +namespace tvm { + +// ============== ObjectPathNode ============== + +ObjectPathNode::ObjectPathNode(const ObjectPathNode* parent) + : parent_(GetRef(parent)), length_(parent == nullptr ? 1 : parent->length_ + 1) {} + +// --- GetParent --- + +Optional ObjectPathNode::GetParent() const { + if (parent_ == nullptr) { + return NullOpt; + } else { + return Downcast(parent_); + } +} + +TVM_REGISTER_GLOBAL("node.ObjectPathGetParent") + .set_body_method(&ObjectPathNode::GetParent); + +// --- Length --- + +int32_t ObjectPathNode::Length() const { return length_; } + +TVM_REGISTER_GLOBAL("node.ObjectPathLength").set_body_method(&ObjectPathNode::Length); + +// --- GetPrefix --- + +ObjectPath ObjectPathNode::GetPrefix(int32_t length) const { + CHECK_GE(length, 1) << "IndexError: Prefix length must be at least 1"; + CHECK_LE(length, Length()) << "IndexError: Attempted to get a prefix longer than the path itself"; + + const ObjectPathNode* node = this; + int32_t suffix_len = Length() - length; + for (int32_t i = 0; i < suffix_len; ++i) { + node = node->ParentNode(); + } + + return GetRef(node); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathGetPrefix") + .set_body_method(&ObjectPathNode::GetPrefix); + +// --- IsPrefixOf --- + +bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { + int32_t this_len = Length(); + if (this_len > other->Length()) { + return false; + } + return this->PathsEqual(other->GetPrefix(this_len)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf") + .set_body_method(&ObjectPathNode::IsPrefixOf); + +// --- Attr --- + +ObjectPath ObjectPathNode::Attr(const char* attr_key) const { + if (attr_key != nullptr) { + return ObjectPath(make_object(this, attr_key)); + } else { + return ObjectPath(make_object(this)); + } +} + +ObjectPath ObjectPathNode::Attr(Optional attr_key) const { + if (attr_key.defined()) { + return ObjectPath(make_object(this, attr_key.value())); + } else { + return ObjectPath(make_object(this)); + } +} + +TVM_REGISTER_GLOBAL("node.ObjectPathAttr") + .set_body_typed([](const ObjectPath& object_path, Optional attr_key) { + return object_path->Attr(attr_key); + }); + +// --- ArrayIndex --- + +ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const { + return ObjectPath(make_object(this, index)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex") + .set_body_method(&ObjectPathNode::ArrayIndex); + +// --- MissingArrayElement --- + +ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const { + return ObjectPath(make_object(this, index)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") + .set_body_method(&ObjectPathNode::MissingArrayElement); + +// --- MapValue --- + +ObjectPath ObjectPathNode::MapValue(ObjectRef key) const { + return ObjectPath(make_object(this, std::move(key))); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMapValue") + .set_body_method(&ObjectPathNode::MapValue); + +// --- MissingMapEntry --- + +ObjectPath ObjectPathNode::MissingMapEntry() const { + return ObjectPath(make_object(this)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry") + .set_body_method(&ObjectPathNode::MissingMapEntry); + +// --- PathsEqual ---- + +bool ObjectPathNode::PathsEqual(const ObjectPath& other) const { + if (!other.defined() || Length() != other->Length()) { + return false; + } + + const ObjectPathNode* lhs = this; + const ObjectPathNode* rhs = static_cast(other.get()); + + while (lhs != nullptr && rhs != nullptr) { + if (lhs->type_index() != rhs->type_index()) { + return false; + } + if (!lhs->LastNodeEqual(rhs)) { + return false; + } + lhs = lhs->ParentNode(); + rhs = rhs->ParentNode(); + } + + return lhs == nullptr && rhs == nullptr; +} + +TVM_REGISTER_GLOBAL("node.ObjectPathEqual") + .set_body_method(&ObjectPathNode::PathsEqual); + +// --- Repr --- + +std::string GetObjectPathRepr(const ObjectPathNode* node) { + std::string ret; + while (node != nullptr) { + std::string node_str = node->LastNodeString(); + ret.append(node_str.rbegin(), node_str.rend()); + node = static_cast(node->GetParent().get()); + } + std::reverse(ret.begin(), ret.end()); + return ret; +} + +static void PrintObjectPathRepr(const ObjectRef& node, ReprPrinter* p) { + p->stream << GetObjectPathRepr(static_cast(node.get())); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// --- Private/protected methods --- + +const ObjectPathNode* ObjectPathNode::ParentNode() const { + return static_cast(parent_.get()); +} + +// ============== ObjectPath ============== + +/* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object()); } + +TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); + +// ============== Individual path classes ============== + +// ----- Root ----- + +RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {} + +bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } + +std::string RootPathNode::LastNodeString() const { return ""; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- AttributeAccess ----- + +AttributeAccessPathNode::AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key) + : ObjectPathNode(parent), attr_key(std::move(attr_key)) {} + +bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherAttrAccess = static_cast(other); + return attr_key == otherAttrAccess->attr_key; +} + +std::string AttributeAccessPathNode::LastNodeString() const { return "." + attr_key; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- UnknownAttributeAccess ----- + +UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(const ObjectPathNode* parent) + : ObjectPathNode(parent) {} + +bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { + // Consider any two unknown attribute accesses unequal + return false; +} + +std::string UnknownAttributeAccessPathNode::LastNodeString() const { + return "."; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- ArrayIndexPath ----- + +ArrayIndexPathNode::ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index) + : ObjectPathNode(parent), index(index) {} + +bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherArrayIndex = static_cast(other); + return index == otherArrayIndex->index; +} + +std::string ArrayIndexPathNode::LastNodeString() const { return "[" + std::to_string(index) + "]"; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- MissingArrayElement ----- + +MissingArrayElementPathNode::MissingArrayElementPathNode(const ObjectPathNode* parent, + int32_t index) + : ObjectPathNode(parent), index(index) {} + +bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherMissingElement = static_cast(other); + return index == otherMissingElement->index; +} + +std::string MissingArrayElementPathNode::LastNodeString() const { + return "[]"; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- MapValue ----- + +MapValuePathNode::MapValuePathNode(const ObjectPathNode* parent, ObjectRef key) + : ObjectPathNode(parent), key(std::move(key)) {} + +bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherMapValue = static_cast(other); + return ObjectEqual()(key, otherMapValue->key); +} + +std::string MapValuePathNode::LastNodeString() const { + std::ostringstream s; + s << "[" << key << "]"; + return s.str(); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- MissingMapEntry ----- + +MissingMapEntryPathNode::MissingMapEntryPathNode(const ObjectPathNode* parent) + : ObjectPathNode(parent) {} + +bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } + +std::string MissingMapEntryPathNode::LastNodeString() const { return "[]"; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +} // namespace tvm diff --git a/tests/python/unittest/test_object_path.py b/tests/python/unittest/test_object_path.py new file mode 100644 index 000000000000..f849c129df59 --- /dev/null +++ b/tests/python/unittest/test_object_path.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm.runtime import object_path +from tvm.runtime.object_path import ObjectPath + + +def test_root_path(): + root = ObjectPath.root() + assert isinstance(root, object_path.RootPath) + assert str(root) == "" + assert len(root) == 1 + assert root == ObjectPath.root() + assert root.parent is None + + +def test_path_attr(): + path = ObjectPath.root().attr("foo") + assert isinstance(path, object_path.AttributeAccessPath) + assert str(path) == ".foo" + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_attr_unknown(): + path = ObjectPath.root().attr(None) + assert isinstance(path, object_path.UnknownAttributeAccessPath) + assert str(path) == "." + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_array_index(): + path = ObjectPath.root().array_index(2) + assert isinstance(path, object_path.ArrayIndexPath) + assert str(path) == "[2]" + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_missing_array_element(): + path = ObjectPath.root().missing_array_element(2) + assert isinstance(path, object_path.MissingArrayElementPath) + assert str(path) == "[]" + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_map_value(): + path = ObjectPath.root().map_value("foo") + assert isinstance(path, object_path.MapValuePath) + assert str(path) == '["foo"]' + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +def test_path_missing_map_entry(): + path = ObjectPath.root().missing_map_entry() + assert isinstance(path, object_path.MissingMapEntryPath) + assert str(path) == "[]" + assert len(path) == 2 + assert path.parent == ObjectPath.root() + + +@pytest.mark.parametrize( + "a, b, expected", + [ + (ObjectPath.root(), ObjectPath.root(), True), + (ObjectPath.root(), ObjectPath.root().attr("foo"), True), + (ObjectPath.root().attr("foo"), ObjectPath.root(), False), + (ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo"), True), + (ObjectPath.root().attr("bar"), ObjectPath.root().attr("foo"), False), + (ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo").array_index(2), True), + (ObjectPath.root().attr("foo").array_index(2), ObjectPath.root().attr("foo"), False), + (ObjectPath.root().attr("foo"), ObjectPath.root().attr("bar").array_index(2), False), + ], +) +def test_path_is_prefix_of(a, b, expected): + assert a.is_prefix_of(b) == expected + + +paths_for_equality_test = [ + ObjectPath.root(), + ObjectPath.root().attr("foo"), + ObjectPath.root().attr("bar"), + ObjectPath.root().array_index(3), + ObjectPath.root().array_index(4), + ObjectPath.root().missing_array_element(3), + ObjectPath.root().missing_array_element(4), + ObjectPath.root().map_value("foo"), + ObjectPath.root().map_value("bar"), + ObjectPath.root().missing_map_entry(), + ObjectPath.root().attr("foo").missing_map_entry(), +] + + +def make_test_params_for_eq_test(): + return [ + pytest.param(idx, path, id="path{}".format(idx)) + for idx, path in enumerate(paths_for_equality_test) + ] + + +@pytest.mark.parametrize("a_idx, a_path", make_test_params_for_eq_test()) +@pytest.mark.parametrize("b_idx, b_path", make_test_params_for_eq_test()) +def test_path_equal(a_idx, a_path, b_idx, b_path): + expected = a_idx == b_idx + result = a_path == b_path + assert result == expected + + +def test_path_get_prefix(): + p1 = ObjectPath.root() + p2 = p1.attr("foo") + p3 = p2.array_index(5) + + assert p3.parent == p2 + assert p2.parent == p1 + assert p1.parent is None + + assert p2.get_prefix(1) == p1 + + assert p3.get_prefix(1) == p1 + assert p3.get_prefix(2) == p2 + assert p3.get_prefix(3) == p3 + + with pytest.raises(IndexError) as e: + p3.get_prefix(0) + assert "Prefix length must be at least 1" in str(e.value) + + with pytest.raises(IndexError) as e: + p3.get_prefix(4) + assert "Attempted to get a prefix longer than the path itself" in str(e.value)