diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index f8e63ed5857a..cf84b9a3a641 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -25,7 +25,6 @@ #include #include -#include #include #include diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index 3f1096b10a8b..5e1165d509c4 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -25,7 +25,6 @@ #define TVM_ARITH_PATTERN_H_ #include -#include #include namespace tvm { diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index caff37cbf6d2..0ca14c43eb47 100755 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -50,7 +50,6 @@ #include #include -#include #include #include diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index 231c04e69821..50e9bcbab273 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -30,7 +30,9 @@ #include #include #include -#include +#include +#include +#include #include #include diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 2295baa0297b..b910d32ceca4 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 5b9e0714e202..c1a012f05318 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -26,7 +26,9 @@ #include #include -#include +#include +#include +#include #include #include diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index 1b0e9a9ea50e..1b9eb9c1b7c8 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -27,7 +27,7 @@ #define TVM_IR_INSTRUMENT_H_ #include -#include +#include #include #include diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 07d582a298e4..638f132e3179 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -29,7 +29,9 @@ #include #include #include -#include +#include +#include +#include #include #include diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 849eda6cd248..ce5ae280e176 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -60,7 +60,8 @@ #include #include #include -#include +#include +#include #include #include diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 4e4e009b2875..c772650809fa 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -51,7 +51,7 @@ #include #include -#include +#include #include #include diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index 6acd2e7dbdd8..c4b54ef0f27d 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -23,7 +23,7 @@ #ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_ #define TVM_NODE_ATTR_REGISTRY_MAP_H_ -#include +#include #include #include diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 7b2a9f8061b4..ad4fb1e1c27a 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -39,7 +39,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index d5309bca894d..6c25c3d2d21d 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -24,7 +24,7 @@ #define TVM_NODE_STRUCTURAL_EQUAL_H_ #include -#include +#include #include #include diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index a661a852780d..887a012cfc93 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -24,7 +24,6 @@ #define TVM_NODE_STRUCTURAL_HASH_H_ #include -#include #include #include diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 15f6b03f0c06..a58bb8750c14 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -29,8 +29,6 @@ #include -#include "tvm/runtime/container.h" - namespace tvm { namespace relay { diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 4a5de33af4b9..751593f94cc0 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -26,7 +26,6 @@ #include #include -#include #include #include diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index e3fd5ae77193..93a56cede77b 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -36,7 +36,7 @@ #include #include -#include +#include #include #include diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 123b7e395faa..b090e3e40063 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -30,7 +30,6 @@ #include #include #include -#include #include #include diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h deleted file mode 100644 index edceabc3525a..000000000000 --- a/include/tvm/runtime/container.h +++ /dev/null @@ -1,3124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container.h - * \brief Common POD(plain old data) container types. - */ -#ifndef TVM_RUNTIME_CONTAINER_H_ -#define TVM_RUNTIME_CONTAINER_H_ - -#ifndef USE_FALLBACK_STL_MAP -#define USE_FALLBACK_STL_MAP 0 -#endif - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -// We use c++14 std::experimental::string_view for optimizing hash computation -// only right now, its usage is limited in this file. Any broader usage of -// std::experiment in our core codebase is discouraged and needs community -// discussion for each use case. Reference for feature test macros of -// string_view: -// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations -// https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros -#if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411 -#define TVM_USE_CXX14_STRING_VIEW_HASH 1 -#else -#define TVM_USE_CXX14_STRING_VIEW_HASH 0 -#endif - -// Tested with clang version 9.0.1 and c++17. It will detect string_view support -// correctly. -#if defined(__cpp_lib_string_view) && __cpp_lib_string_view >= 201606 -#define TVM_USE_CXX17_STRING_VIEW_HASH 1 -#else -#define TVM_USE_CXX17_STRING_VIEW_HASH 0 -#endif - -#if TVM_USE_CXX17_STRING_VIEW_HASH -#include -#elif TVM_USE_CXX14_STRING_VIEW_HASH -#include -#endif - -#include -#include -#include - -namespace llvm { -// String to llvm object compatibility. -class StringRef; -} // namespace llvm - -namespace tvm { -namespace runtime { - -// Forward declare TVMArgValue -class TVMArgValue; - -/*! \brief String-aware ObjectRef equal functor */ -struct ObjectHash { - /*! - * \brief Calculate the hash code of an ObjectRef - * \param a The given ObjectRef - * \return Hash code of a, string hash for strings and pointer address otherwise. - */ - size_t operator()(const ObjectRef& a) const; -}; - -/*! \brief String-aware ObjectRef hash functor */ -struct ObjectEqual { - /*! - * \brief Check if the two ObjectRef are equal - * \param a One ObjectRef - * \param b The other ObjectRef - * \return String equality if both are strings, pointer address equality otherwise. - */ - bool operator()(const ObjectRef& a, const ObjectRef& b) const; -}; - -/*! - * \brief Base template for classes with array like memory layout. - * - * It provides general methods to access the memory. The memory - * layout is ArrayType + [ElemType]. The alignment of ArrayType - * and ElemType is handled by the memory allocator. - * - * \tparam ArrayType The array header type, contains object specific metadata. - * \tparam ElemType The type of objects stored in the array right after - * ArrayType. - * - * \code - * // Example usage of the template to define a simple array wrapper - * class ArrayObj : public InplaceArrayBase { - * public: - * // Wrap EmplaceInit to initialize the elements - * template - * void Init(Iterator begin, Iterator end) { - * size_t num_elems = std::distance(begin, end); - * auto it = begin; - * this->size = 0; - * for (size_t i = 0; i < num_elems; ++i) { - * InplaceArrayBase::EmplaceInit(i, *it++); - * this->size++; - * } - * } - * } - * - * void test_function() { - * vector fields; - * auto ptr = make_inplace_array_object(fields.size()); - * ptr->Init(fields.begin(), fields.end()); - * - * // Access the 0th element in the array. - * assert(ptr->operator[](0) == fields[0]); - * } - * - * \endcode - */ -template -class InplaceArrayBase { - public: - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Const reference to ElemType at the index. - */ - const ElemType& operator[](size_t idx) const { - size_t size = Self()->GetSize(); - ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Reference to ElemType at the index. - */ - ElemType& operator[](size_t idx) { - size_t size = Self()->GetSize(); - ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Destroy the Inplace Array Base object - */ - ~InplaceArrayBase() { - if (!(std::is_standard_layout::value && std::is_trivial::value)) { - size_t size = Self()->GetSize(); - for (size_t i = 0; i < size; ++i) { - ElemType* fp = reinterpret_cast(AddressOf(i)); - fp->ElemType::~ElemType(); - } - } - } - - protected: - /*! - * \brief Construct a value in place with the arguments. - * - * \tparam Args Type parameters of the arguments. - * \param idx Index of the element. - * \param args Arguments to construct the new value. - * - * \note Please make sure ArrayType::GetSize returns 0 before first call of - * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. - */ - template - void EmplaceInit(size_t idx, Args&&... args) { - void* field_ptr = AddressOf(idx); - new (field_ptr) ElemType(std::forward(args)...); - } - - /*! - * \brief Return the self object for the array. - * - * \return Pointer to ArrayType. - */ - inline ArrayType* Self() const { - return static_cast(const_cast(this)); - } - - /*! - * \brief Return the raw pointer to the element at idx. - * - * \param idx The index of the element. - * \return Raw pointer to the element. - */ - void* AddressOf(size_t idx) const { - static_assert( - alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); - - size_t kDataStart = sizeof(ArrayType); - ArrayType* self = Self(); - char* data_start = reinterpret_cast(self) + kDataStart; - return data_start + idx * sizeof(ElemType); - } -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class IterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit IterAdapter(TIter iter) : iter_(iter) {} - IterAdapter& operator++() { - ++iter_; - return *this; - } - IterAdapter& operator--() { - --iter_; - return *this; - } - IterAdapter operator++(int) { - IterAdapter copy = *this; - ++iter_; - return copy; - } - IterAdapter operator--(int) { - IterAdapter copy = *this; - --iter_; - return copy; - } - - IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - - IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const IterAdapter& rhs) const { - return iter_ - rhs.iter_; - } - - bool operator==(IterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(IterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class ReverseIterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} - ReverseIterAdapter& operator++() { - --iter_; - return *this; - } - ReverseIterAdapter& operator--() { - ++iter_; - return *this; - } - ReverseIterAdapter& operator++(int) { - ReverseIterAdapter copy = *this; - --iter_; - return copy; - } - ReverseIterAdapter& operator--(int) { - ReverseIterAdapter copy = *this; - ++iter_; - return copy; - } - ReverseIterAdapter operator+(difference_type offset) const { - return ReverseIterAdapter(iter_ - offset); - } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const ReverseIterAdapter& rhs) const { - return rhs.iter_ - iter_; - } - - bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! \brief array node content in array */ -class ArrayNode : public Object, public InplaceArrayBase { - public: - /*! \return The size of the array */ - size_t size() const { return this->size_; } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const ObjectRef at(int64_t i) const { return this->operator[](i); } - - /*! \return begin constant iterator */ - const ObjectRef* begin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } - - /*! \return end constant iterator */ - const ObjectRef* end() const { return begin() + size_; } - - /*! \brief Release reference to all the elements */ - void clear() { ShrinkBy(size_); } - - /*! - * \brief Set i-th element of the array in-place - * \param i The index - * \param item The value to be set - */ - void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } - - /*! - * \brief Constructs a container and copy from another - * \param cap The capacity of the container - * \param from Source of the copy - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr CopyFrom(int64_t cap, ArrayNode* from) { - int64_t size = from->size_; - ICHECK_GE(cap, size) << "ValueError: not enough capacity"; - ObjectPtr p = ArrayNode::Empty(cap); - ObjectRef* write = p->MutableBegin(); - ObjectRef* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) ObjectRef(*read++); - } - return p; - } - - /*! - * \brief Constructs a container and move from another - * \param cap The capacity of the container - * \param from Source of the move - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr MoveFrom(int64_t cap, ArrayNode* from) { - int64_t size = from->size_; - ICHECK_GE(cap, size) << "ValueError: not enough capacity"; - ObjectPtr p = ArrayNode::Empty(cap); - ObjectRef* write = p->MutableBegin(); - ObjectRef* read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t& i = p->size_ = 0; i < size; ++i) { - new (write++) ObjectRef(std::move(*read++)); - } - from->size_ = 0; - return p; - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr CreateRepeated(int64_t n, const ObjectRef& val) { - ObjectPtr p = ArrayNode::Empty(n); - ObjectRef* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < n; ++i) { - new (itr++) ObjectRef(val); - } - return p; - } - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; - static constexpr const char* _type_key = "Array"; - TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); - - private: - /*! \return Size of initialized memory, used by InplaceArrayBase. */ - size_t GetSize() const { return this->size_; } - - /*! \return begin mutable iterator */ - ObjectRef* MutableBegin() const { - return static_cast(InplaceArrayBase::AddressOf(0)); - } - - /*! \return end mutable iterator */ - ObjectRef* MutableEnd() const { return MutableBegin() + size_; } - - /*! - * \brief Create an ArrayNode with the given capacity. - * \param n Required capacity - * \return Ref-counted ArrayNode requested - */ - static ObjectPtr Empty(int64_t n = kInitSize) { - ICHECK_GE(n, 0); - ObjectPtr p = make_inplace_array_object(n); - p->capacity_ = n; - p->size_ = 0; - return p; - } - - /*! - * \brief Inplace-initialize the elements starting idx from [first, last) - * \param idx The starting point - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return Self - */ - template - ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { - ObjectRef* itr = MutableBegin() + idx; - for (; first != last; ++first) { - ObjectRef ref = *first; - new (itr++) ObjectRef(std::move(ref)); - } - return this; - } - - /*! - * \brief Move elements from right to left, requires src_begin > dst - * \param dst Destination - * \param src_begin The start point of copy (inclusive) - * \param src_end The end point of copy (exclusive) - * \return Self - */ - ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { - ObjectRef* from = MutableBegin() + src_begin; - ObjectRef* to = MutableBegin() + dst; - while (src_begin++ != src_end) { - *to++ = std::move(*from++); - } - return this; - } - - /*! - * \brief Move elements from left to right, requires src_begin < dst - * \param dst Destination - * \param src_begin The start point of move (inclusive) - * \param src_end The end point of move (exclusive) - * \return Self - */ - ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { - ObjectRef* from = MutableBegin() + src_end; - ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); - while (src_begin++ != src_end) { - *--to = std::move(*--from); - } - return this; - } - - /*! - * \brief Enlarges the size of the array - * \param delta Size enlarged, should be positive - * \param val Default value - * \return Self - */ - ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { - ObjectRef* itr = MutableEnd(); - while (delta-- > 0) { - new (itr++) ObjectRef(val); - ++size_; - } - return this; - } - - /*! - * \brief Shrinks the size of the array - * \param delta Size shrinked, should be positive - * \return Self - */ - ArrayNode* ShrinkBy(int64_t delta) { - ObjectRef* itr = MutableEnd(); - while (delta-- > 0) { - (--itr)->ObjectRef::~ObjectRef(); - --size_; - } - return this; - } - - /*! \brief Number of elements used */ - int64_t size_; - - /*! \brief Number of elements allocated */ - int64_t capacity_; - - /*! \brief Initial size of ArrayNode */ - static constexpr int64_t kInitSize = 4; - - /*! \brief Expansion factor of the Array */ - static constexpr int64_t kIncFactor = 2; - - // CRTP parent class - friend InplaceArrayBase; - - // Reference class - template - friend class Array; - - // To specialize make_object - friend ObjectPtr make_object<>(); -}; - -/*! - * \brief Array, container representing a contigious sequence of ObjectRefs. - * - * Array implements in-place copy-on-write semantics. - * - * As in typical copy-on-write, a method which would typically mutate the array - * instead opaquely copies the underlying container, and then acts on its copy. - * - * If the array has reference count equal to one, we directly update the - * container in place without copying. This is optimization is sound because - * when the reference count is equal to one this reference is guranteed to be - * the sole pointer to the container. - * - * - * operator[] only provides const access, use Set to mutate the content. - * \tparam T The content ObjectRef type. - */ -template ::value>::type> -class Array : public ObjectRef { - public: - using value_type = T; - // constructors - /*! - * \brief default constructor - */ - Array() { data_ = ArrayNode::Empty(); } - - /*! - * \brief move constructor - * \param other source - */ - Array(Array&& other) : ObjectRef() { // NOLINT(*) - data_ = std::move(other.data_); - } - - /*! - * \brief copy constructor - * \param other source - */ - Array(const Array& other) : ObjectRef() { // NOLINT(*) - data_ = other.data_; - } - - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Array(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief Constructor from iterator - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType first, IterType last) { - Assign(first, last); - } - - /*! - * \brief constructor from initializer list - * \param init The initializer list - */ - Array(std::initializer_list init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector& init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - */ - explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } - - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(Array&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(const Array& other) { - data_ = other.data_; - return *this; - } - - public: - // iterators - struct ValueConverter { - using ResultType = T; - static T convert(const ObjectRef& n) { return DowncastNoCheck(n); } - }; - - using iterator = IterAdapter; - using reverse_iterator = ReverseIterAdapter; - - /*! \return begin iterator */ - iterator begin() const { return iterator(GetArrayNode()->begin()); } - - /*! \return end iterator */ - iterator end() const { return iterator(GetArrayNode()->end()); } - - /*! \return rbegin iterator */ - reverse_iterator rbegin() const { - // ArrayNode::end() is never nullptr - return reverse_iterator(GetArrayNode()->end() - 1); - } - - /*! \return rend iterator */ - reverse_iterator rend() const { - // ArrayNode::begin() is never nullptr - return reverse_iterator(GetArrayNode()->begin() - 1); - } - - public: - // const methods in std::vector - /*! - * \brief Immutably read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const T operator[](int64_t i) const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK(0 <= i && i < p->size_) - << "IndexError: indexing " << i << " on an array of size " << p->size_; - return DowncastNoCheck(*(p->begin() + i)); - } - - /*! \return The size of the array */ - size_t size() const { - ArrayNode* p = GetArrayNode(); - return p == nullptr ? 0 : GetArrayNode()->size_; - } - - /*! \return The capacity of the array */ - size_t capacity() const { - ArrayNode* p = GetArrayNode(); - return p == nullptr ? 0 : GetArrayNode()->capacity_; - } - - /*! \return Whether array is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the array */ - const T front() const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; - return DowncastNoCheck(*(p->begin())); - } - - /*! \return The last element of the array */ - const T back() const { - ArrayNode* p = GetArrayNode(); - ICHECK(p != nullptr) << "ValueError: cannot index a null array"; - ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; - return DowncastNoCheck(*(p->end() - 1)); - } - - public: - // mutation in std::vector, implements copy-on-write - - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - void push_back(const T& item) { - ArrayNode* p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, item); - } - - /*! - * \brief Insert an element into the given position - * \param position An iterator pointing to the insertion point - * \param val The element to insert - */ - void insert(iterator position, const T& val) { - ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - auto addr = CopyOnWrite(1) // - ->EnlargeBy(1) // - ->MoveElementsRight(idx + 1, idx, size) // - ->MutableBegin(); - new (addr + idx) ObjectRef(val); - } - - /*! - * \brief Insert a range of elements into the given position - * \param position An iterator pointing to the insertion point - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - template - void insert(iterator position, IterType first, IterType last) { - if (first == last) { - return; - } - ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - int64_t numel = std::distance(first, last); - CopyOnWrite(numel) - ->EnlargeBy(numel) - ->MoveElementsRight(idx + numel, idx, size) - ->InitRange(idx, first, last); - } - - /*! \brief Remove the last item of the list */ - void pop_back() { - ICHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; - int64_t size = GetArrayNode()->size_; - ICHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; - CopyOnWrite()->ShrinkBy(1); - } - - /*! - * \brief Erase an element on the given position - * \param position An iterator pointing to the element to be erased - */ - void erase(iterator position) { - ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; - int64_t st = std::distance(begin(), position); - int64_t size = GetArrayNode()->size_; - ICHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st - << ", because Array size is " << size; - CopyOnWrite() // - ->MoveElementsLeft(st, st + 1, size) // - ->ShrinkBy(1); - } - - /*! - * \brief Erase a given range of elements - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - void erase(iterator first, iterator last) { - if (first == last) { - return; - } - ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; - int64_t size = GetArrayNode()->size_; - int64_t st = std::distance(begin(), first); - int64_t ed = std::distance(begin(), last); - ICHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"; - ICHECK(0 <= st && st <= size && 0 <= ed && ed <= size) - << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" - << ", because array size is " << size; - CopyOnWrite() // - ->MoveElementsLeft(st, ed, size) // - ->ShrinkBy(ed - st); - } - - /*! - * \brief Resize the array. - * \param n The new size. - */ - void resize(int64_t n) { - ICHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; - if (data_ == nullptr) { - SwitchContainer(n); - return; - } - int64_t size = GetArrayNode()->size_; - if (size < n) { - CopyOnWrite(n - size)->EnlargeBy(n - size); - } else if (size > n) { - CopyOnWrite()->ShrinkBy(size - n); - } - } - - /*! - * \brief Make sure the list has the capacity of at least n - * \param n lower bound of the capacity - */ - void reserve(int64_t n) { - if (data_ == nullptr || n > GetArrayNode()->capacity_) { - SwitchContainer(n); - } - } - - /*! \brief Release reference to all the elements */ - void clear() { - if (data_ != nullptr) { - ArrayNode* p = CopyOnWrite(); - p->clear(); - } - } - - public: - // Array's own methods - - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - void Set(int64_t i, T value) { - ArrayNode* p = this->CopyOnWrite(); - ICHECK(0 <= i && i < p->size_) - << "IndexError: indexing " << i << " on an array of size " << p->size_; - *(p->MutableBegin() + i) = std::move(value); - } - - /*! \return The underlying ArrayNode */ - ArrayNode* GetArrayNode() const { return static_cast(data_.get()); } - - /*! - * \brief Helper function to apply fmutate to mutate an array. - * \param fmutate The transformation function T -> T. - * \tparam F the type of the mutation function. - * \note This function performs copy on write optimization. - */ - template - void MutateByApply(F fmutate) { - if (data_ == nullptr) { - return; - } - struct StackFrame { - ArrayNode* p; - ObjectRef* itr; - int64_t i; - int64_t size; - }; - std::unique_ptr s = std::make_unique(); - s->p = GetArrayNode(); - s->itr = s->p->MutableBegin(); - s->i = 0; - s->size = s->p->size_; - if (!data_.unique()) { - // Loop invariant: keeps iterating when - // 1) data is not unique - // 2) no elements are actually mutated yet - for (; s->i < s->size; ++s->i, ++s->itr) { - T new_elem = fmutate(DowncastNoCheck(*s->itr)); - // do nothing when there is no mutation - if (new_elem.same_as(*s->itr)) { - continue; - } - // loop invariant breaks when the first real mutation happens - // we copy the elements into a new unique array - ObjectPtr copy = ArrayNode::CopyFrom(s->p->capacity_, s->p); - s->itr = copy->MutableBegin() + (s->i++); - *s->itr++ = std::move(new_elem); - data_ = std::move(copy); - // make sure `data_` is unique and break - break; - } - } - // when execution comes to this line, it is guaranteed that either - // 1) i == size - // or 2) data_.unique() is true - for (; s->i < s->size; ++s->i, ++s->itr) { - *s->itr = std::move(fmutate(std::move(DowncastNoCheck(std::move(*s->itr))))); - } - } - - /*! - * \brief reset the array to content from iterator. - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - void Assign(IterType first, IterType last) { - int64_t cap = std::distance(first, last); - ICHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size"; - ArrayNode* p = GetArrayNode(); - if (p != nullptr && data_.unique() && p->capacity_ >= cap) { - // do not have to make new space - p->clear(); - } else { - // create new space - data_ = ArrayNode::Empty(cap); - p = GetArrayNode(); - } - // To ensure exception safety, size is only incremented after the initialization succeeds - ObjectRef* itr = p->MutableBegin(); - for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { - new (itr) ObjectRef(*first); - } - } - - /*! - * \brief Copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - ArrayNode* CopyOnWrite() { - if (data_ == nullptr) { - return SwitchContainer(ArrayNode::kInitSize); - } - if (!data_.unique()) { - return SwitchContainer(capacity()); - } - return static_cast(data_.get()); - } - - /*! \brief specify container node */ - using ContainerType = ArrayNode; - - private: - /*! - * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. - * \param reserve_extra Number of extra slots needed - * \return ArrayNode pointer to the unique copy - */ - ArrayNode* CopyOnWrite(int64_t reserve_extra) { - ArrayNode* p = GetArrayNode(); - if (p == nullptr) { - // necessary to get around the constexpr address issue before c++17 - const int64_t kInitSize = ArrayNode::kInitSize; - return SwitchContainer(std::max(kInitSize, reserve_extra)); - } - if (p->capacity_ >= p->size_ + reserve_extra) { - return CopyOnWrite(); - } - int64_t cap = p->capacity_ * ArrayNode::kIncFactor; - cap = std::max(cap, p->size_ + reserve_extra); - return SwitchContainer(cap); - } - - /*! - * \brief Move or copy the ArrayNode to new address with the given capacity - * \param capacity The capacity requirement of the new address - */ - ArrayNode* SwitchContainer(int64_t capacity) { - if (data_ == nullptr) { - data_ = ArrayNode::Empty(capacity); - } else if (data_.unique()) { - data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); - } else { - data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); - } - return static_cast(data_.get()); - } -}; - -/*! - * \brief Concat two Arrays. - * \param lhs first Array to be concatenated. - * \param rhs second Array to be concatenated. - * \return The concatenated Array. Original Arrays are kept unchanged. - */ -template ::value>::type> -inline Array Concat(Array lhs, const Array& rhs) { - for (const auto& x : rhs) { - lhs.push_back(x); - } - return std::move(lhs); -} - -// Specialize make_object to make sure it is correct. -template <> -inline ObjectPtr make_object() { - return ArrayNode::Empty(); -} - -/*! \brief An object representing a structure or enumeration. */ -class ADTObj : public Object, public InplaceArrayBase { - public: - /*! \brief The tag representing the constructor used. */ - int32_t tag; - /*! \brief Number of fields in the ADT object. */ - uint32_t size; - // The fields of the structure follows directly in memory. - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT; - static constexpr const char* _type_key = "runtime.ADT"; - TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); - - private: - /*! - * \return The number of elements in the array. - */ - size_t GetSize() const { return size; } - - /*! - * \brief Initialize the elements in the array. - * - * \tparam Iterator Iterator type of the array. - * \param begin The begin iterator. - * \param end The end iterator. - */ - template - void Init(Iterator begin, Iterator end) { - size_t num_elems = std::distance(begin, end); - this->size = 0; - auto it = begin; - for (size_t i = 0; i < num_elems; ++i) { - InplaceArrayBase::EmplaceInit(i, *it++); - // Only increment size after the initialization succeeds - this->size++; - } - } - - friend class ADT; - friend InplaceArrayBase; -}; - -/*! \brief reference to algebraic data type objects. */ -class ADT : public ObjectRef { - public: - /*! - * \brief construct an ADT object reference. - * \param tag The tag of the ADT object. - * \param fields The fields of the ADT object. - * \return The constructed ADT object reference. - */ - ADT(int32_t tag, std::vector fields) : ADT(tag, fields.begin(), fields.end()){}; - - /*! - * \brief construct an ADT object reference. - * \param tag The tag of the ADT object. - * \param begin The begin iterator to the start of the fields array. - * \param end The end iterator to the end of the fields array. - * \return The constructed ADT object reference. - */ - template - ADT(int32_t tag, Iterator begin, Iterator end) { - size_t num_elems = std::distance(begin, end); - auto ptr = make_inplace_array_object(num_elems); - ptr->tag = tag; - ptr->Init(begin, end); - data_ = std::move(ptr); - } - - /*! - * \brief construct an ADT object reference. - * \param tag The tag of the ADT object. - * \param init The initializer list of fields. - * \return The constructed ADT object reference. - */ - ADT(int32_t tag, std::initializer_list init) : ADT(tag, init.begin(), init.end()){}; - - /*! - * \brief Access element at index. - * - * \param idx The array index - * \return const ObjectRef - */ - const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); } - - /*! - * \brief Return the ADT tag. - */ - int32_t tag() const { return operator->()->tag; } - - /*! - * \brief Return the number of fields. - */ - size_t size() const { return operator->()->size; } - - /*! - * \brief Construct a tuple object. - * - * \tparam Args Type params of tuple feilds. - * \param args Tuple fields. - * \return ADT The tuple object reference. - */ - template - static ADT Tuple(Args&&... args) { - return ADT(0, std::forward(args)...); - } - - TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); -}; - -/*! \brief An object representing string. It's POD type. */ -class StringObj : public Object { - public: - /*! \brief The pointer to string data. */ - const char* data; - - /*! \brief The length of the string object. */ - uint64_t size; - - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString; - static constexpr const char* _type_key = "runtime.String"; - TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); - - private: - /*! \brief String object which is moved from std::string container. */ - class FromStd; - - friend class String; -}; - -/*! - * \brief Reference to string objects. - * - * \code - * - * // Example to create runtime String reference object from std::string - * std::string s = "hello world"; - * - * // You can create the reference from existing std::string - * String ref{std::move(s)}; - * - * // You can rebind the reference to another string. - * ref = std::string{"hello world2"}; - * - * // You can use the reference as hash map key - * std::unordered_map m; - * m[ref] = 1; - * - * // You can compare the reference object with other string objects - * assert(ref == "hello world", true); - * - * // You can convert the reference to std::string again - * string s2 = (string)ref; - * - * \endcode - */ -class String : public ObjectRef { - public: - /*! - * \brief Construct an empty string. - */ - String() : String(std::string()) {} - /*! - * \brief Construct a new String object - * - * \param other The moved/copied std::string object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - String(std::string other); // NOLINT(*) - - /*! - * \brief Construct a new String object - * - * \param other a char array. - */ - String(const char* other) // NOLINT(*) - : String(std::string(other)) {} - - /*! - * \brief Change the value the reference object points to. - * - * \param other The value for the new String - * - */ - inline String& operator=(std::string other); - - /*! - * \brief Change the value the reference object points to. - * - * \param other The value for the new String - */ - inline String& operator=(const char* other); - - /*! - * \brief Compares this String object to other - * - * \param other The String to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const String& other) const { - return memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this String object to other - * - * \param other The string to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const std::string& other) const { - return memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this to other - * - * \param other The character array to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const char* other) const { - return memncmp(data(), other, size(), std::strlen(other)); - } - - /*! - * \brief Returns a pointer to the char array in the string. - * - * \return const char* - */ - const char* c_str() const { return get()->data; } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const { - const auto* ptr = get(); - return ptr->size; - } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t length() const { return size(); } - - /*! - * \brief Retun if the string is empty - * - * \return true if empty, false otherwise. - */ - bool empty() const { return size() == 0; } - - /*! - * \brief Read an element. - * \param pos The position at which to read the character. - * - * \return The char at position - */ - char at(size_t pos) const { - if (pos < size()) { - return data()[pos]; - } else { - throw std::out_of_range("tvm::String index out of bounds"); - } - } - - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char* data() const { return get()->data; } - - /*! - * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { return std::string{get()->data, size()}; } - - // LLVM compatibility function, implemented in src/target/llvm/llvm_common.h - /*! - * \brief Convert String to an llvm::StringRef object - * - * \return llvm::StringRef - */ - inline operator llvm::StringRef() const; - - /*! - * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String - * \param val The value to be checked - * \return A boolean indicating if val can be converted to String - */ - inline static bool CanConvertFrom(const TVMArgValue& val); - - /*! - * \brief Hash the binary bytes - * \param data The data pointer - * \param size The size of the bytes. - * \return the hash value. - */ - static size_t HashBytes(const char* data, size_t size) { - // This function falls back to string copy with c++11 compiler and is - // recommended to be compiled with c++14 -#if TVM_USE_CXX17_STRING_VIEW_HASH - return std::hash()(std::string_view(data, size)); -#elif TVM_USE_CXX14_STRING_VIEW_HASH - return std::hash()(std::experimental::string_view(data, size)); -#else - return std::hash()(std::string(data, size)); -#endif - } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); - - private: - /*! - * \brief Compare two char sequence - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * \return int zero if both char sequences compare equal. negative if this - * appear before other, positive otherwise. - */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); - - /*! - * \brief Concatenate two char sequences - * - * \param lhs Pointers to the lhs char array - * \param lhs_size The size of the lhs char array - * \param rhs Pointers to the rhs char array - * \param rhs_size The size of the rhs char array - * - * \return The concatenated char sequence - */ - static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { - std::string ret(lhs, lhs_size); - ret.append(rhs, rhs_size); - return String(ret); - } - - // Overload + operator - friend String operator+(const String& lhs, const String& rhs); - friend String operator+(const String& lhs, const std::string& rhs); - friend String operator+(const std::string& lhs, const String& rhs); - friend String operator+(const String& lhs, const char* rhs); - friend String operator+(const char* lhs, const String& rhs); - - friend struct tvm::runtime::ObjectEqual; -}; - -/*! \brief An object representing string moved from std::string. */ -class StringObj::FromStd : public StringObj { - public: - /*! - * \brief Construct a new FromStd object - * - * \param other The moved/copied std::string object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - explicit FromStd(std::string other) : data_container{other} {} - - private: - /*! \brief Container that holds the memory. */ - std::string data_container; - - friend class String; -}; - -inline String::String(std::string other) { - auto ptr = make_object(std::move(other)); - ptr->size = ptr->data_container.size(); - ptr->data = ptr->data_container.data(); - data_ = std::move(ptr); -} - -inline String& String::operator=(std::string other) { - String replace{std::move(other)}; - data_.swap(replace.data_); - return *this; -} - -inline String& String::operator=(const char* other) { return operator=(std::string(other)); } - -inline String operator+(const String& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const std::string& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const std::string& lhs, const String& rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const char* lhs, const String& rhs) { - size_t lhs_size = std::strlen(lhs); - size_t rhs_size = rhs.size(); - return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String& lhs, const char* rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = std::strlen(rhs); - return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); -} - -// Overload < operator -inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } - -// Overload > operator -inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } - -// Overload <= operator -inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } - -// Overload >= operator -inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } - -inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; } - -// Overload == operator -inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } - -inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } - -// Overload != operator -inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } - -inline std::ostream& operator<<(std::ostream& out, const String& input) { - out.write(input.data(), input.size()); - return out; -} - -inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { - if (lhs == rhs && lhs_count == rhs_count) return 0; - - for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { - if (lhs[i] < rhs[i]) return -1; - if (lhs[i] > rhs[i]) return 1; - } - if (lhs_count < rhs_count) { - return -1; - } else if (lhs_count > rhs_count) { - return 1; - } else { - return 0; - } -} - -inline size_t ObjectHash::operator()(const ObjectRef& a) const { - if (const auto* str = a.as()) { - return String::HashBytes(str->data, str->size); - } - return ObjectPtrHash()(a); -} - -inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const { - if (a.same_as(b)) { - return true; - } - if (const auto* str_a = a.as()) { - if (const auto* str_b = b.as()) { - return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; - } - } - return false; -} - -/*! \brief Helper to represent nullptr for optional. */ -struct NullOptType {}; - -/*! - * \brief Optional container that to represent to a Nullable variant of T. - * \tparam T The original ObjectRef. - * - * \code - * - * Optional opt0 = nullptr; - * Optional opt1 = String("xyz"); - * ICHECK(opt0 == nullptr); - * ICHECK(opt1 == "xyz"); - * - * \endcode - */ -template -class Optional : public ObjectRef { - public: - using ContainerType = typename T::ContainerType; - static_assert(std::is_base_of::value, "Optional is only defined for ObjectRef."); - // default constructors. - Optional() = default; - Optional(const Optional&) = default; - Optional(Optional&&) = default; - Optional& operator=(const Optional&) = default; - Optional& operator=(Optional&&) = default; - /*! - * \brief Construct from an ObjectPtr - * whose type already matches the ContainerType. - * \param ptr - */ - explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} - /*! \brief Nullopt handling */ - Optional(NullOptType) {} // NOLINT(*) - // nullptr handling. - // disallow implicit conversion as 0 can be implicitly converted to nullptr_t - explicit Optional(std::nullptr_t) {} - Optional& operator=(std::nullptr_t) { - data_ = nullptr; - return *this; - } - // normal value handling. - Optional(T other) // NOLINT(*) - : ObjectRef(std::move(other)) {} - Optional& operator=(T other) { - ObjectRef::operator=(std::move(other)); - return *this; - } - // delete the int constructor - // since Optional(0) is ambiguious - // 0 can be implicitly casted to nullptr_t - explicit Optional(int val) = delete; - Optional& operator=(int val) = delete; - /*! - * \return A not-null container value in the optional. - * \note This function performs not-null checking. - */ - T value() const { - ICHECK(data_ != nullptr); - return T(data_); - } - /*! - * \return The contained value if the Optional is not null - * otherwise return the default_value. - */ - T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } - - /*! \return Whether the container is not nullptr.*/ - explicit operator bool() const { return *this != nullptr; } - // operator overloadings - bool operator==(std::nullptr_t) const { return data_ == nullptr; } - bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - auto operator==(const Optional& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(value() == other.value()); - if (same_as(other)) return RetType(true); - if (*this != nullptr && other != nullptr) { - return value() == other.value(); - } else { - // one of them is nullptr. - return RetType(false); - } - } - auto operator!=(const Optional& other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(value() != other.value()); - if (same_as(other)) return RetType(false); - if (*this != nullptr && other != nullptr) { - return value() != other.value(); - } else { - // one of them is nullptr. - return RetType(true); - } - } - auto operator==(const T& other) const { - using RetType = decltype(value() == other); - if (same_as(other)) return RetType(true); - if (*this != nullptr) return value() == other; - return RetType(false); - } - auto operator!=(const T& other) const { return !(*this == other); } - template - auto operator==(const U& other) const { - using RetType = decltype(value() == other); - if (*this == nullptr) return RetType(false); - return value() == other; - } - template - auto operator!=(const U& other) const { - using RetType = decltype(value() != other); - if (*this == nullptr) return RetType(true); - return value() != other; - } - static constexpr bool _type_is_nullable = true; -}; - -/*! - * \brief An object representing a closure. This object is used by both the - * Relay VM and interpreter. - */ -class ClosureObj : public Object { - public: - static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure; - static constexpr const char* _type_key = "runtime.Closure"; - TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object); -}; - -/*! \brief reference to closure. */ -class Closure : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); -}; - -#if (USE_FALLBACK_STL_MAP != 0) - -/*! \brief Shared content of all specializations of hash map */ -class MapNode : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = ObjectRef; - /*! \brief Type of the values in the hash map */ - using mapped_type = ObjectRef; - /*! \brief Type of the actual underlying container */ - using ContainerType = std::unordered_map; - /*! \brief Iterator class */ - using iterator = ContainerType::iterator; - /*! \brief Iterator class */ - using const_iterator = ContainerType::const_iterator; - /*! \brief Type of value stored in the hash map */ - using KVType = ContainerType::value_type; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); - - /*! - * \brief Number of elements in the SmallMapNode - * \return The result - */ - size_t size() const { return data_.size(); } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return data_.count(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return data_.at(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return data_.at(key); } - /*! \return begin iterator */ - iterator begin() { return data_.begin(); } - /*! \return const begin iterator */ - const_iterator begin() const { return data_.begin(); } - /*! \return end iterator */ - iterator end() { return data_.end(); } - /*! \return end iterator */ - const_iterator end() const { return data_.end(); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - const_iterator find(const key_type& key) const { return data_.find(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) { return data_.find(key); } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { data_.erase(position); } - /*! - * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { data_.erase(key); } - /*! - * \brief Create an empty container - * \return The object created - */ - static ObjectPtr Empty() { return make_object(); } - - protected: - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static ObjectPtr CreateFromRange(IterType first, IterType last) { - ObjectPtr p = make_object(); - p->data_ = ContainerType(first, last); - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - MapNode* map_node = static_cast(map->get()); - map_node->data_[kv.first] = kv.second; - } - /*! - * \brief Create an empty container with elements copying from another MapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(MapNode* from) { - ObjectPtr p = make_object(); - p->data_ = ContainerType(from->data_.begin(), from->data_.end()); - return p; - } - /*! \brief The real container storing data */ - ContainerType data_; - template - friend class Map; -}; - -#else - -/*! \brief Shared content of all specializations of hash map */ -class MapNode : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = ObjectRef; - /*! \brief Type of the values in the hash map */ - using mapped_type = ObjectRef; - /*! \brief Type of value stored in the hash map */ - using KVType = std::pair; - /*! \brief Iterator class */ - class iterator; - - static_assert(std::is_standard_layout::value, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; - static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); - - /*! - * \brief Number of elements in the SmallMapNode - * \return The result - */ - size_t size() const { return size_; } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key); - /*! \return begin iterator */ - iterator begin() const; - /*! \return end iterator */ - iterator end() const; - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const; - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position); - /*! - * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { erase(find(key)); } - - class iterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = int64_t; - using value_type = KVType; - using pointer = KVType*; - using reference = KVType&; - /*! \brief Default constructor */ - iterator() : index(0), self(nullptr) {} - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { - return index == other.index && self == other.self; - } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return !(*this == other); } - /*! \brief De-reference iterators */ - pointer operator->() const; - /*! \brief De-reference iterators */ - reference operator*() const { return *((*this).operator->()); } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++(); - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--(); - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - iterator copy = *this; - --(*this); - return copy; - } - - protected: - /*! \brief Construct by value */ - iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} - /*! \brief The position on the array */ - uint64_t index; - /*! \brief The container it points to */ - const MapNode* self; - - friend class DenseMapNode; - friend class SmallMapNode; - }; - /*! - * \brief Create an empty container - * \return The object created - */ - static inline ObjectPtr Empty(); - - protected: - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static inline ObjectPtr CreateFromRange(IterType first, IterType last); - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr* map); - /*! - * \brief Create an empty container with elements copying from another SmallMapNode - * \param from The source container - * \return The object created - */ - static inline ObjectPtr CopyFrom(MapNode* from); - /*! \brief number of slots minus 1 */ - uint64_t slots_; - /*! \brief number of entries in the container */ - uint64_t size_; - // Reference class - template - friend class Map; -}; - -/*! \brief A specialization of small-sized hash map */ -class SmallMapNode : public MapNode, - public runtime::InplaceArrayBase { - private: - static constexpr uint64_t kInitSize = 2; - static constexpr uint64_t kMaxSize = 4; - - public: - using MapNode::iterator; - using MapNode::KVType; - - /*! \brief Defaults to the destructor of InplaceArrayBase */ - ~SmallMapNode() = default; - /*! - * \brief Count the number of times a key exists in the SmallMapNode - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return find(key).index < size_; } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { - iterator itr = find(key); - ICHECK(itr.index < size_) << "IndexError: key is not in Map"; - return itr->second; - } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { - iterator itr = find(key); - ICHECK(itr.index < size_) << "IndexError: key is not in Map"; - return itr->second; - } - /*! \return begin iterator */ - iterator begin() const { return iterator(0, this); } - /*! \return end iterator */ - iterator end() const { return iterator(size_, this); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - KVType* ptr = static_cast(AddressOf(0)); - for (uint64_t i = 0; i < size_; ++i, ++ptr) { - if (ObjectEqual()(ptr->first, key)) { - return iterator(i, this); - } - } - return iterator(size_, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { Erase(position.index); } - - private: - /*! - * \brief Remove a position in SmallMapNode - * \param index The position to be removed - */ - void Erase(const uint64_t index) { - if (index >= size_) { - return; - } - KVType* begin = static_cast(AddressOf(0)); - KVType* last = begin + (size_ - 1); - if (index + 1 == size_) { - last->first.ObjectRef::~ObjectRef(); - last->second.ObjectRef::~ObjectRef(); - } else { - *(begin + index) = std::move(*last); - } - size_ -= 1; - } - /*! - * \brief Create an empty container - * \param n Number of empty slots - * \return The object created - */ - static ObjectPtr Empty(uint64_t n = kInitSize) { - using ::tvm::runtime::make_inplace_array_object; - ObjectPtr p = make_inplace_array_object(n); - p->size_ = 0; - p->slots_ = n; - return p; - } - /*! - * \brief Create an empty container initialized with a given range - * \param n Number of empty slots - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - * \return The object created - */ - template - static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { - ObjectPtr p = Empty(n); - KVType* ptr = static_cast(p->AddressOf(0)); - for (; first != last; ++first, ++p->size_) { - new (ptr++) KVType(*first); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another SmallMapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(SmallMapNode* from) { - KVType* first = static_cast(from->AddressOf(0)); - KVType* last = first + from->size_; - return CreateFromRange(from->size_, first, last); - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - SmallMapNode* map_node = static_cast(map->get()); - iterator itr = map_node->find(kv.first); - if (itr.index < map_node->size_) { - itr->second = kv.second; - return; - } - if (map_node->size_ < map_node->slots_) { - KVType* ptr = static_cast(map_node->AddressOf(map_node->size_)); - new (ptr) KVType(kv); - ++map_node->size_; - return; - } - uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); - next_size = std::min(next_size, uint64_t(kMaxSize)); - ICHECK_GT(next_size, map_node->slots_); - ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); - InsertMaybeReHash(kv, &new_map); - *map = std::move(new_map); - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return static_cast(AddressOf(index)); } - /*! \brief A size function used by InplaceArrayBase */ - uint64_t GetSize() const { return size_; } - - protected: - friend class MapNode; - friend class DenseMapNode; - friend class runtime::InplaceArrayBase; -}; - -/*! \brief A specialization of hash map that implements the idea of array-based hash map. - * Another reference implementation can be found [1]. - * - * A. Overview - * - * DenseMapNode did several improvements over traditional separate chaining hash, - * in terms of cache locality, memory footprints and data organization. - * - * A1. Implicit linked list. For better cache locality, instead of using linked list - * explicitly for each bucket, we store list data into a single array that spans contiguously - * in memory, and then carefully design access patterns to make sure most of them fall into - * a single cache line. - * - * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and - * traversal. This can be divided in 3 parts. - * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, - * which means the slot is empty but not allowed to be written. - * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is - * head of a linked list. - * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit - * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when - * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are - * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to - * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, - * then x must be one of the 126 pre-defined values. - * - * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. - * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. - * 16 key-value pairs. - * - * B. Implementation details - * - * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid - * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, - * we use the Fibonacci Hashing [2] trick. - * - * B2. Traverse a linked list in the array. - * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i - * indicates that it is list head, then we found the head; otherwise the list is empty. No probing - * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we - * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of - * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). - * - * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this - * element is in the linked list, and if not, we put it at the end by probing the next empty - * position in one of the 126 candidate positions. If the linked list does not even exist, but the - * slot for list head has been occupied by another linked list, we should find this intruder another - * place. - * - * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing - * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the - * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list - * head. - * - * [1] https://github.com/skarupke/flat_hash_map - * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ - * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - */ -class DenseMapNode : public MapNode { - private: - /*! \brief The number of elements in a memory block */ - static constexpr int kBlockCap = 16; - /*! \brief Maximum load factor of the hash map */ - static constexpr double kMaxLoadFactor = 0.99; - /*! \brief Binary representation of the metadata of an empty slot */ - static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); - /*! \brief Binary representation of the metadata of a protected slot */ - static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); - /*! \brief Number of probing choices available */ - static constexpr int kNumJumpDists = 126; - /*! \brief Head of the implicit linked list */ - struct ListNode; - /*! \brief POD type of a block of memory */ - struct Block { - uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)]; - }; - static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect"); - static_assert(std::is_standard_layout::value, "Block is not standard layout"); - - public: - using MapNode::iterator; - - /*! - * \brief Destroy the DenseMapNode - */ - ~DenseMapNode() { this->Reset(); } - /*! \return The number of elements of the key */ - size_t count(const key_type& key) const { return !Search(key).IsNone(); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return At(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return At(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - ListNode node = Search(key); - return node.IsNone() ? end() : iterator(node.index, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { - uint64_t index = position.index; - if (position.self != nullptr && index <= this->slots_) { - Erase(ListNode(index, this)); - } - } - /*! \return begin iterator */ - iterator begin() const { - if (slots_ == 0) { - return iterator(0, this); - } - for (uint64_t index = 0; index <= slots_; ++index) { - if (!ListNode(index, this).IsEmpty()) { - return iterator(index, this); - } - } - return iterator(slots_ + 1, this); - } - /*! \return end iterator */ - iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); } - - private: - /*! - * \brief Search for the given key - * \param key The key - * \return ListNode that associated with the key - */ - ListNode Search(const key_type& key) const { - if (this->size_ == 0) { - return ListNode(); - } - for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { - if (ObjectEqual()(key, iter.Key())) { - return iter; - } - } - return ListNode(); - } - /*! - * \brief Search for the given key, throw exception if not exists - * \param key The key - * \return ListNode that associated with the key - */ - mapped_type& At(const key_type& key) const { - ListNode iter = Search(key); - ICHECK(!iter.IsNone()) << "IndexError: key is not in Map"; - return iter.Val(); - } - /*! - * \brief Try to insert a key, or do nothing if already exists - * \param key The indexing key - * \param result The linked-list entry found or just constructed - * \return A boolean, indicating if actual insertion happens - */ - bool TryInsert(const key_type& key, ListNode* result) { - if (slots_ == 0) { - return false; - } - // required that `iter` to be the head of a linked list through which we can iterator - ListNode iter = IndexFromHash(ObjectHash()(key)); - // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list - // Case 1: empty - if (iter.IsEmpty()) { - iter.NewHead(KVType(key, ObjectRef(nullptr))); - this->size_ += 1; - *result = iter; - return true; - } - // Case 2: body of an irrelevant list - if (!iter.IsHead()) { - // we move the elements around and construct the single-element linked list - return IsFull() ? false : TrySpareListHead(iter, key, result); - } - // Case 3: head of the relevant list - // we iterate through the linked list until the end - // make sure `iter` is the previous element of `next` - ListNode next = iter; - do { - // find equal item, do not insert - if (ObjectEqual()(key, next.Key())) { - *result = next; - return true; - } - // make sure `iter` is the previous element of `next` - iter = next; - } while (next.MoveToNext(this)); - // `iter` is the tail of the linked list - // always check capacity before insertion - if (IsFull()) { - return false; - } - // find the next empty slot - uint8_t jump; - if (!iter.GetNextEmpty(this, &jump, result)) { - return false; - } - result->NewTail(KVType(key, ObjectRef(nullptr))); - // link `iter` to `empty`, and move forward - iter.SetJump(jump); - this->size_ += 1; - return true; - } - /*! - * \brief Spare an entry to be the head of a linked list. - * As described in B3, during insertion, it is possible that the entire linked list does not - * exist, but the slot of its head has been occupied by other linked lists. In this case, we need - * to spare the slot by moving away the elements to another valid empty one to make insertion - * possible. - * \param target The given entry to be spared - * \param key The indexing key - * \param result The linked-list entry constructed as the head - * \return A boolean, if actual insertion happens - */ - bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { - // `target` is not the head of the linked list - // move the original item of `target` (if any) - // and construct new item on the position `target` - // To make `target` empty, we - // 1) find `w` the previous element of `target` in the linked list - // 2) copy the linked list starting from `r = target` - // 3) paste them after `w` - // read from the linked list after `r` - ListNode r = target; - // write to the tail of `w` - ListNode w = target.FindPrev(this); - // after `target` is moved, we disallow writing to the slot - bool is_first = true; - uint8_t r_meta, jump; - ListNode empty; - do { - // `jump` describes how `w` is jumped to `empty` - // rehash if there is no empty space after `w` - if (!w.GetNextEmpty(this, &jump, &empty)) { - return false; - } - // move `r` to `empty` - empty.NewTail(std::move(r.Data())); - // clear the metadata of `r` - r_meta = r.Meta(); - if (is_first) { - is_first = false; - r.SetProtected(); - } else { - r.SetEmpty(); - } - // link `w` to `empty`, and move forward - w.SetJump(jump); - w = empty; - // move `r` forward as well - } while (r.MoveToNext(this, r_meta)); - // finally we have done moving the linked list - // fill data_ into `target` - target.NewHead(KVType(key, ObjectRef(nullptr))); - this->size_ += 1; - *result = target; - return true; - } - /*! - * \brief Remove a ListNode - * \param iter The node to be removed - */ - void Erase(const ListNode& iter) { - this->size_ -= 1; - if (!iter.HasNext()) { - // `iter` is the last - if (!iter.IsHead()) { - // cut the link if there is any - iter.FindPrev(this).SetJump(0); - } - iter.Data().KVType::~KVType(); - iter.SetEmpty(); - } else { - ListNode last = iter, prev = iter; - for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { - } - iter.Data() = std::move(last.Data()); - last.SetEmpty(); - prev.SetJump(0); - } - } - /*! \brief Clear the container to empty, release all entries and memory acquired */ - void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->slots_); - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = data_[bi].bytes; - KVType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - data_ptr->KVType::~KVType(); - } - } - } - ReleaseMemory(); - } - /*! \brief Release the memory acquired by the container without deleting its entries stored inside - */ - void ReleaseMemory() { - delete[] data_; - data_ = nullptr; - slots_ = 0; - size_ = 0; - fib_shift_ = 63; - } - /*! - * \brief Create an empty container - * \param fib_shift The fib shift provided - * \param n_slots Number of slots required, should be power-of-two - * \return The object created - */ - static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { - ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize)); - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(n_slots - 1); - Block* block = p->data_ = new Block[n_blocks]; - p->slots_ = n_slots - 1; - p->size_ = 0; - p->fib_shift_ = fib_shift; - for (uint64_t i = 0; i < n_blocks; ++i, ++block) { - std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another DenseMapNode - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(DenseMapNode* from) { - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->slots_); - p->data_ = new Block[n_blocks]; - p->slots_ = from->slots_; - p->size_ = from->size_; - p->fib_shift_ = from->fib_shift_; - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr_from = from->data_[bi].bytes; - KVType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); - uint8_t* meta_ptr_to = p->data_[bi].bytes; - KVType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; - ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { - uint8_t& meta = *meta_ptr_to = *meta_ptr_from; - ICHECK(meta != kProtectedSlot); - if (meta != uint8_t(kEmptySlot)) { - new (data_ptr_to) KVType(*data_ptr_from); - } - } - } - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - DenseMapNode* map_node = static_cast(map->get()); - ListNode iter; - // Try to insert. If succeed, we simply return - if (map_node->TryInsert(kv.first, &iter)) { - iter.Val() = kv.second; - return; - } - ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize)); - // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); - // Insert the given `kv` into the new hash map - InsertMaybeReHash(kv, &p); - uint64_t n_blocks = CalcNumBlocks(map_node->slots_); - // Then Insert data from the original block. - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = map_node->data_[bi].bytes; - KVType* data_ptr = reinterpret_cast(map_node->data_[bi].bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { - meta = uint8_t(kEmptySlot); - KVType kv = std::move(*data_ptr); - InsertMaybeReHash(kv, &p); - } - } - } - map_node->ReleaseMemory(); - *map = p; - } - /*! - * \brief Check whether the hash table is full - * \return A boolean indicating whether hash table is full - */ - bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { - for (++index; index <= slots_; ++index) { - if (!ListNode(index, this).IsEmpty()) { - return index; - } - } - return slots_ + 1; - } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { - while (index != 0) { - index -= 1; - if (!ListNode(index, this).IsEmpty()) { - return index; - } - } - return slots_ + 1; - } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } - /*! \brief Construct from hash code */ - ListNode IndexFromHash(uint64_t hash_value) const { - return ListNode(FibHash(hash_value, fib_shift_), this); - } - /*! \brief Construct from hash code if the position is head of list */ - ListNode GetListHead(uint64_t hash_value) const { - ListNode node = IndexFromHash(hash_value); - return node.IsHead() ? node : ListNode(); - } - /*! \brief Construct the number of blocks in the hash table */ - static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { - uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; - return (n_slots + kBlockCap - 1) / kBlockCap; - } - /*! - * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. - * \param cap The lower-bound of the required capacity - * \param fib_shift The result shift for Fibonacci Hashing - * \param n_slots The result number of slots - */ - static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { - uint32_t shift = 64; - uint64_t slots = 1; - for (uint64_t c = cap; c; c >>= 1) { - shift -= 1; - slots <<= 1; - } - ICHECK_GT(slots, cap); - if (slots < cap * 2) { - *fib_shift = shift - 1; - *n_slots = slots << 1; - } else { - *fib_shift = shift; - *n_slots = slots; - } - } - /*! - * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. - * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. - * \param hash_value The raw hash value - * \param fib_shift The shift in Fibonacci Hashing - * \return An index calculated using Fibonacci Hashing - */ - static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { - constexpr uint64_t coeff = 11400714819323198485ull; - return (coeff * hash_value) >> fib_shift; - } - /*! \brief The implicit in-place linked list used to index a chain */ - struct ListNode { - /*! \brief Construct None */ - ListNode() : index(0), block(nullptr) {} - /*! \brief Construct from position */ - ListNode(uint64_t index, const DenseMapNode* self) - : index(index), block(self->data_ + (index / kBlockCap)) {} - /*! \brief Metadata on the entry */ - uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } - /*! \brief Data on the entry */ - KVType& Data() const { - return *(reinterpret_cast(block->bytes + kBlockCap + - (index % kBlockCap) * sizeof(KVType))); - } - /*! \brief Key on the entry */ - key_type& Key() const { return Data().first; } - /*! \brief Value on the entry */ - mapped_type& Val() const { return Data().second; } - /*! \brief If the entry is head of linked list */ - bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } - /*! \brief If the entry is none */ - bool IsNone() const { return block == nullptr; } - /*! \brief If the entry is empty slot */ - bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } - /*! \brief If the entry is protected slot */ - bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } - /*! \brief Set the entry to be empty */ - void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } - /*! \brief Set the entry to be protected */ - void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } - /*! \brief Set the entry's jump to its next entry */ - void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } - /*! \brief Construct a head of linked list in-place */ - void NewHead(KVType v) const { - Meta() = 0b00000000; - new (&Data()) KVType(std::move(v)); - } - /*! \brief Construct a tail of linked list in-place */ - void NewTail(KVType v) const { - Meta() = 0b10000000; - new (&Data()) KVType(std::move(v)); - } - /*! \brief If the entry has next entry on the linked list */ - bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapNode* self, uint8_t meta) { - uint64_t offset = kNextProbeLocation[meta & 0b01111111]; - if (offset == 0) { - index = 0; - block = nullptr; - return false; - } - index = (index + offset) & (self->slots_); - block = self->data_ + (index / kBlockCap); - return true; - } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); } - /*! \brief Get the previous entry on the linked list */ - ListNode FindPrev(const DenseMapNode* self) const { - // start from the head of the linked list, which must exist - ListNode next = self->IndexFromHash(ObjectHash()(Key())); - // `prev` is always the previous item of `next` - ListNode prev = next; - for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { - } - return prev; - } - /*! \brief Get the next empty jump */ - bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const { - for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { - ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self); - if (candidate.IsEmpty()) { - *jump = idx; - *result = candidate; - return true; - } - } - return false; - } - /*! \brief Index on the real array */ - uint64_t index; - /*! \brief Pointer to the actual block */ - Block* block; - }; - - protected: - /*! \brief fib shift in Fibonacci Hashing */ - uint32_t fib_shift_; - /*! \brief array of data blocks */ - Block* data_; - /* clang-format off */ - /*! \brief Candidates of probing distance */ - TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - // Quadratic probing with triangle numbers. See also: - // 1) https://en.wikipedia.org/wiki/Quadratic_probing - // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - // 3) https://github.com/skarupke/flat_hash_map - 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, - 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, - 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, - 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, - 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, - 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, - 2211, 2278, 2346, 2415, 2485, 2556, 2628, - // larger triangle numbers - 8515, 19110, 42778, 96141, 216153, - 486591, 1092981, 2458653, 5532801, 12442566, - 27993903, 62983476, 141717030, 318844378, 717352503, - 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, - 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, - 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, - 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, - 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626, - 1029107982097042876, 2315492959180353330, 5209859154120846435, - }; - /* clang-format on */ - friend class MapNode; -}; - -#define TVM_DISPATCH_MAP(base, var, body) \ - { \ - using TSmall = SmallMapNode*; \ - using TDense = DenseMapNode*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapNode::kMaxSize) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -#define TVM_DISPATCH_MAP_CONST(base, var, body) \ - { \ - using TSmall = const SmallMapNode*; \ - using TDense = const DenseMapNode*; \ - uint64_t slots = base->slots_; \ - if (slots <= SmallMapNode::kMaxSize) { \ - TSmall var = static_cast(base); \ - body; \ - } else { \ - TDense var = static_cast(base); \ - body; \ - } \ - } - -inline MapNode::iterator::pointer MapNode::iterator::operator->() const { - TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); -} - -inline MapNode::iterator& MapNode::iterator::operator++() { - TVM_DISPATCH_MAP_CONST(self, p, { - index = p->IncItr(index); - return *this; - }); -} - -inline MapNode::iterator& MapNode::iterator::operator--() { - TVM_DISPATCH_MAP_CONST(self, p, { - index = p->DecItr(index); - return *this; - }); -} - -inline size_t MapNode::count(const key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); -} - -inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); -} - -inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) { - TVM_DISPATCH_MAP(this, p, { return p->at(key); }); -} - -inline MapNode::iterator MapNode::begin() const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); -} - -inline MapNode::iterator MapNode::end() const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); -} - -inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const { - TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); -} - -inline void MapNode::erase(const MapNode::iterator& position) { - TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); -} - -#undef TVM_DISPATCH_MAP -#undef TVM_DISPATCH_MAP_CONST - -inline ObjectPtr MapNode::Empty() { return SmallMapNode::Empty(); } - -inline ObjectPtr MapNode::CopyFrom(MapNode* from) { - if (from->slots_ <= SmallMapNode::kMaxSize) { - return SmallMapNode::CopyFrom(static_cast(from)); - } else { - return DenseMapNode::CopyFrom(static_cast(from)); - } -} - -template -inline ObjectPtr MapNode::CreateFromRange(IterType first, IterType last) { - int64_t _cap = std::distance(first, last); - if (_cap < 0) { - return SmallMapNode::Empty(); - } - uint64_t cap = static_cast(_cap); - if (cap < SmallMapNode::kMaxSize) { - return SmallMapNode::CreateFromRange(cap, first, last); - } - uint32_t fib_shift; - uint64_t n_slots; - DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots); - ObjectPtr obj = DenseMapNode::Empty(fib_shift, n_slots); - for (; first != last; ++first) { - KVType kv(*first); - DenseMapNode::InsertMaybeReHash(kv, &obj); - } - return obj; -} - -inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { - constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; - MapNode* base = static_cast(map->get()); - if (base->slots_ < kSmallMapMaxSize) { - SmallMapNode::InsertMaybeReHash(kv, map); - } else if (base->slots_ == kSmallMapMaxSize) { - if (base->size_ < base->slots_) { - SmallMapNode::InsertMaybeReHash(kv, map); - } else { - ObjectPtr new_map = MapNode::CreateFromRange(base->begin(), base->end()); - DenseMapNode::InsertMaybeReHash(kv, &new_map); - *map = std::move(new_map); - } - } else { - DenseMapNode::InsertMaybeReHash(kv, map); - } -} - -template <> -inline ObjectPtr make_object<>() = delete; - -#endif - -/*! - * \brief Map container of NodeRef->NodeRef in DSL graph. - * Map implements copy on write semantics, which means map is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam K The key NodeRef type. - * \tparam V The value NodeRef type. - */ -template ::value>::type, - typename = typename std::enable_if::value>::type> -class Map : public ObjectRef { - public: - using key_type = K; - using mapped_type = V; - class iterator; - /*! - * \brief default constructor - */ - Map() { data_ = MapNode::Empty(); } - /*! - * \brief move constructor - * \param other source - */ - Map(Map&& other) { data_ = std::move(other.data_); } - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map& other) : ObjectRef(other.data_) {} - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Map(IterType begin, IterType end) { - data_ = MapNode::CreateFromRange(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Map(std::initializer_list> init) { - data_ = MapNode::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief constructor from unordered_map - * \param init The unordered_map - */ - template - Map(const std::unordered_map& init) { // NOLINT(*) - data_ = MapNode::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V at(const K& key) const { return DowncastNoCheck(GetMapNode()->at(key)); } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V operator[](const K& key) const { return this->at(key); } - /*! \return The size of the array */ - size_t size() const { - MapNode* n = GetMapNode(); - return n == nullptr ? 0 : n->size(); - } - /*! \return The number of elements of the key */ - size_t count(const K& key) const { - MapNode* n = GetMapNode(); - return n == nullptr ? 0 : GetMapNode()->count(key); - } - /*! \return whether array is empty */ - bool empty() const { return size() == 0; } - /*! \brief Release reference to all the elements */ - void clear() { - MapNode* n = GetMapNode(); - if (n != nullptr) { - data_ = MapNode::Empty(); - } - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - void Set(const K& key, const V& value) { - CopyOnWrite(); - MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_); - } - /*! \return begin iterator */ - iterator begin() const { return iterator(GetMapNode()->begin()); } - /*! \return end iterator */ - iterator end() const { return iterator(GetMapNode()->end()); } - /*! \return find the key and returns the associated iterator */ - iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } - - void erase(const K& key) { CopyOnWrite()->erase(key); } - - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - MapNode* CopyOnWrite() { - if (data_.get() == nullptr) { - data_ = MapNode::Empty(); - } else if (!data_.unique()) { - data_ = MapNode::CopyFrom(GetMapNode()); - } - return GetMapNode(); - } - /*! \brief specify container node */ - using ContainerType = MapNode; - - /*! \brief Iterator of the hash map */ - class iterator { - public: - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = int64_t; - using value_type = const std::pair; - using pointer = value_type*; - using reference = value_type; - - iterator() : itr() {} - - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { return itr == other.itr; } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return itr != other.itr; } - /*! \brief De-reference iterators is not allowed */ - pointer operator->() const = delete; - /*! \brief De-reference iterators */ - reference operator*() const { - auto& kv = *itr; - return std::make_pair(DowncastNoCheck(kv.first), DowncastNoCheck(kv.second)); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++() { - ++itr; - return *this; - } - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - - private: - iterator(const MapNode::iterator& itr) // NOLINT(*) - : itr(itr) {} - - template - friend class Map; - - MapNode::iterator itr; - }; - - private: - /*! \brief Return data_ as type of pointer of MapNode */ - MapNode* GetMapNode() const { return static_cast(data_.get()); } -}; - -/*! - * \brief Merge two Maps. - * \param lhs the first Map to merge. - * \param rhs the second Map to merge. - * @return The merged Array. Original Maps are kept unchanged. - */ -template ::value>::type, - typename = typename std::enable_if::value>::type> -inline Map Merge(Map lhs, const Map& rhs) { - for (const auto& p : rhs) { - lhs.Set(p.first, p.second); - } - return std::move(lhs); -} - -} // namespace runtime - -// expose the functions to the root namespace. -using runtime::Array; -using runtime::ArrayNode; -using runtime::Downcast; -using runtime::IterAdapter; -using runtime::make_object; -using runtime::Map; -using runtime::MapNode; -using runtime::Object; -using runtime::ObjectEqual; -using runtime::ObjectHash; -using runtime::ObjectPtr; -using runtime::ObjectPtrEqual; -using runtime::ObjectPtrHash; -using runtime::ObjectRef; -using runtime::Optional; -using runtime::String; -using runtime::StringObj; -constexpr runtime::NullOptType NullOpt{}; -} // namespace tvm - -namespace std { - -template <> -struct hash<::tvm::runtime::String> { - std::size_t operator()(const ::tvm::runtime::String& str) const { - return ::tvm::runtime::String::HashBytes(str.data(), str.size()); - } -}; -} // namespace std - -#endif // TVM_RUNTIME_CONTAINER_H_ diff --git a/include/tvm/runtime/container/adt.h b/include/tvm/runtime/container/adt.h new file mode 100644 index 000000000000..20c4f796d741 --- /dev/null +++ b/include/tvm/runtime/container/adt.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/adt.h + * \brief Runtime ADT container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_ADT_H_ +#define TVM_RUNTIME_CONTAINER_ADT_H_ + +#include +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! \brief An object representing a structure or enumeration. */ +class ADTObj : public Object, public InplaceArrayBase { + public: + /*! \brief The tag representing the constructor used. */ + int32_t tag; + /*! \brief Number of fields in the ADT object. */ + uint32_t size; + // The fields of the structure follows directly in memory. + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT; + static constexpr const char* _type_key = "runtime.ADT"; + TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); + + private: + /*! + * \return The number of elements in the array. + */ + size_t GetSize() const { return size; } + + /*! + * \brief Initialize the elements in the array. + * + * \tparam Iterator Iterator type of the array. + * \param begin The begin iterator. + * \param end The end iterator. + */ + template + void Init(Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + this->size = 0; + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + InplaceArrayBase::EmplaceInit(i, *it++); + // Only increment size after the initialization succeeds + this->size++; + } + } + + friend class ADT; + friend InplaceArrayBase; +}; + +/*! \brief reference to algebraic data type objects. */ +class ADT : public ObjectRef { + public: + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param fields The fields of the ADT object. + * \return The constructed ADT object reference. + */ + ADT(int32_t tag, std::vector fields) : ADT(tag, fields.begin(), fields.end()){}; + + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param begin The begin iterator to the start of the fields array. + * \param end The end iterator to the end of the fields array. + * \return The constructed ADT object reference. + */ + template + ADT(int32_t tag, Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + auto ptr = make_inplace_array_object(num_elems); + ptr->tag = tag; + ptr->Init(begin, end); + data_ = std::move(ptr); + } + + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param init The initializer list of fields. + * \return The constructed ADT object reference. + */ + ADT(int32_t tag, std::initializer_list init) : ADT(tag, init.begin(), init.end()){}; + + /*! + * \brief Access element at index. + * + * \param idx The array index + * \return const ObjectRef + */ + const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); } + + /*! + * \brief Return the ADT tag. + */ + int32_t tag() const { return operator->()->tag; } + + /*! + * \brief Return the number of fields. + */ + size_t size() const { return operator->()->size; } + + /*! + * \brief Construct a tuple object. + * + * \tparam Args Type params of tuple feilds. + * \param args Tuple fields. + * \return ADT The tuple object reference. + */ + template + static ADT Tuple(Args&&... args) { + return ADT(0, std::forward(args)...); + } + + TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); +}; +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_CONTAINER_ADT_H_ diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h new file mode 100644 index 000000000000..8830653da88c --- /dev/null +++ b/include/tvm/runtime/container/array.h @@ -0,0 +1,739 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/array.h + * \brief Runtime Array container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_ARRAY_H_ +#define TVM_RUNTIME_CONTAINER_ARRAY_H_ + +#include +#include +#include +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! \brief array node content in array */ +class ArrayNode : public Object, public InplaceArrayBase { + public: + /*! \return The size of the array */ + size_t size() const { return this->size_; } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const ObjectRef at(int64_t i) const { return this->operator[](i); } + + /*! \return begin constant iterator */ + const ObjectRef* begin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } + + /*! \return end constant iterator */ + const ObjectRef* end() const { return begin() + size_; } + + /*! \brief Release reference to all the elements */ + void clear() { ShrinkBy(size_); } + + /*! + * \brief Set i-th element of the array in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } + + /*! + * \brief Constructs a container and copy from another + * \param cap The capacity of the container + * \param from Source of the copy + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr CopyFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + ICHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(*read++); + } + return p; + } + + /*! + * \brief Constructs a container and move from another + * \param cap The capacity of the container + * \param from Source of the move + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr MoveFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + ICHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(std::move(*read++)); + } + from->size_ = 0; + return p; + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr CreateRepeated(int64_t n, const ObjectRef& val) { + ObjectPtr p = ArrayNode::Empty(n); + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < n; ++i) { + new (itr++) ObjectRef(val); + } + return p; + } + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; + static constexpr const char* _type_key = "Array"; + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); + + private: + /*! \return Size of initialized memory, used by InplaceArrayBase. */ + size_t GetSize() const { return this->size_; } + + /*! \return begin mutable iterator */ + ObjectRef* MutableBegin() const { + return static_cast(InplaceArrayBase::AddressOf(0)); + } + + /*! \return end mutable iterator */ + ObjectRef* MutableEnd() const { return MutableBegin() + size_; } + + /*! + * \brief Create an ArrayNode with the given capacity. + * \param n Required capacity + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr Empty(int64_t n = kInitSize) { + ICHECK_GE(n, 0); + ObjectPtr p = make_inplace_array_object(n); + p->capacity_ = n; + p->size_ = 0; + return p; + } + + /*! + * \brief Inplace-initialize the elements starting idx from [first, last) + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return Self + */ + template + ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { + ObjectRef* itr = MutableBegin() + idx; + for (; first != last; ++first) { + ObjectRef ref = *first; + new (itr++) ObjectRef(std::move(ref)); + } + return this; + } + + /*! + * \brief Move elements from right to left, requires src_begin > dst + * \param dst Destination + * \param src_begin The start point of copy (inclusive) + * \param src_end The end point of copy (exclusive) + * \return Self + */ + ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + ObjectRef* from = MutableBegin() + src_begin; + ObjectRef* to = MutableBegin() + dst; + while (src_begin++ != src_end) { + *to++ = std::move(*from++); + } + return this; + } + + /*! + * \brief Move elements from left to right, requires src_begin < dst + * \param dst Destination + * \param src_begin The start point of move (inclusive) + * \param src_end The end point of move (exclusive) + * \return Self + */ + ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + ObjectRef* from = MutableBegin() + src_end; + ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); + while (src_begin++ != src_end) { + *--to = std::move(*--from); + } + return this; + } + + /*! + * \brief Enlarges the size of the array + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) ObjectRef(val); + ++size_; + } + return this; + } + + /*! + * \brief Shrinks the size of the array + * \param delta Size shrinked, should be positive + * \return Self + */ + ArrayNode* ShrinkBy(int64_t delta) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->ObjectRef::~ObjectRef(); + --size_; + } + return this; + } + + /*! \brief Number of elements used */ + int64_t size_; + + /*! \brief Number of elements allocated */ + int64_t capacity_; + + /*! \brief Initial size of ArrayNode */ + static constexpr int64_t kInitSize = 4; + + /*! \brief Expansion factor of the Array */ + static constexpr int64_t kIncFactor = 2; + + // CRTP parent class + friend InplaceArrayBase; + + // Reference class + template + friend class Array; + + // To specialize make_object + friend ObjectPtr make_object<>(); +}; + +/*! + * \brief Array, container representing a contigious sequence of ObjectRefs. + * + * Array implements in-place copy-on-write semantics. + * + * As in typical copy-on-write, a method which would typically mutate the array + * instead opaquely copies the underlying container, and then acts on its copy. + * + * If the array has reference count equal to one, we directly update the + * container in place without copying. This is optimization is sound because + * when the reference count is equal to one this reference is guranteed to be + * the sole pointer to the container. + * + * + * operator[] only provides const access, use Set to mutate the content. + * \tparam T The content ObjectRef type. + */ +template ::value>::type> +class Array : public ObjectRef { + public: + using value_type = T; + // constructors + /*! + * \brief default constructor + */ + Array() { data_ = ArrayNode::Empty(); } + + /*! + * \brief move constructor + * \param other source + */ + Array(Array&& other) : ObjectRef() { // NOLINT(*) + data_ = std::move(other.data_); + } + + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array& other) : ObjectRef() { // NOLINT(*) + data_ = other.data_; + } + + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + Array(IterType first, IterType last) { + Assign(first, last); + } + + /*! + * \brief constructor from initializer list + * \param init The initializer list + */ + Array(std::initializer_list init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector& init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } + + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(Array&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(const Array& other) { + data_ = other.data_; + return *this; + } + + public: + // iterators + struct ValueConverter { + using ResultType = T; + static T convert(const ObjectRef& n) { return DowncastNoCheck(n); } + }; + + using iterator = IterAdapter; + using reverse_iterator = ReverseIterAdapter; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetArrayNode()->begin()); } + + /*! \return end iterator */ + iterator end() const { return iterator(GetArrayNode()->end()); } + + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { + // ArrayNode::end() is never nullptr + return reverse_iterator(GetArrayNode()->end() - 1); + } + + /*! \return rend iterator */ + reverse_iterator rend() const { + // ArrayNode::begin() is never nullptr + return reverse_iterator(GetArrayNode()->begin() - 1); + } + + public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ArrayNode* p = GetArrayNode(); + ICHECK(p != nullptr) << "ValueError: cannot index a null array"; + ICHECK(0 <= i && i < p->size_) + << "IndexError: indexing " << i << " on an array of size " << p->size_; + return DowncastNoCheck(*(p->begin() + i)); + } + + /*! \return The size of the array */ + size_t size() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->size_; + } + + /*! \return The capacity of the array */ + size_t capacity() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->capacity_; + } + + /*! \return Whether array is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the array */ + const T front() const { + ArrayNode* p = GetArrayNode(); + ICHECK(p != nullptr) << "ValueError: cannot index a null array"; + ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return DowncastNoCheck(*(p->begin())); + } + + /*! \return The last element of the array */ + const T back() const { + ArrayNode* p = GetArrayNode(); + ICHECK(p != nullptr) << "ValueError: cannot index a null array"; + ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return DowncastNoCheck(*(p->end() - 1)); + } + + public: + // mutation in std::vector, implements copy-on-write + + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T& item) { + ArrayNode* p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, item); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T& val) { + ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + auto addr = CopyOnWrite(1) // + ->EnlargeBy(1) // + ->MoveElementsRight(idx + 1, idx, size) // + ->MutableBegin(); + new (addr + idx) ObjectRef(val); + } + + /*! + * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template + void insert(iterator position, IterType first, IterType last) { + if (first == last) { + return; + } + ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + int64_t numel = std::distance(first, last); + CopyOnWrite(numel) + ->EnlargeBy(numel) + ->MoveElementsRight(idx + numel, idx, size) + ->InitRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + ICHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; + int64_t size = GetArrayNode()->size_; + ICHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; + CopyOnWrite()->ShrinkBy(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t st = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + ICHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st + << ", because Array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, st + 1, size) // + ->ShrinkBy(1); + } + + /*! + * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t size = GetArrayNode()->size_; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + ICHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"; + ICHECK(0 <= st && st <= size && 0 <= ed && ed <= size) + << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" + << ", because array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, ed, size) // + ->ShrinkBy(ed - st); + } + + /*! + * \brief Resize the array. + * \param n The new size. + */ + void resize(int64_t n) { + ICHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; + if (data_ == nullptr) { + SwitchContainer(n); + return; + } + int64_t size = GetArrayNode()->size_; + if (size < n) { + CopyOnWrite(n - size)->EnlargeBy(n - size); + } else if (size > n) { + CopyOnWrite()->ShrinkBy(size - n); + } + } + + /*! + * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { + if (data_ == nullptr || n > GetArrayNode()->capacity_) { + SwitchContainer(n); + } + } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + ArrayNode* p = CopyOnWrite(); + p->clear(); + } + } + + public: + // Array's own methods + + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + void Set(int64_t i, T value) { + ArrayNode* p = this->CopyOnWrite(); + ICHECK(0 <= i && i < p->size_) + << "IndexError: indexing " << i << " on an array of size " << p->size_; + *(p->MutableBegin() + i) = std::move(value); + } + + /*! \return The underlying ArrayNode */ + ArrayNode* GetArrayNode() const { return static_cast(data_.get()); } + + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template + void MutateByApply(F fmutate) { + if (data_ == nullptr) { + return; + } + struct StackFrame { + ArrayNode* p; + ObjectRef* itr; + int64_t i; + int64_t size; + }; + std::unique_ptr s = std::make_unique(); + s->p = GetArrayNode(); + s->itr = s->p->MutableBegin(); + s->i = 0; + s->size = s->p->size_; + if (!data_.unique()) { + // Loop invariant: keeps iterating when + // 1) data is not unique + // 2) no elements are actually mutated yet + for (; s->i < s->size; ++s->i, ++s->itr) { + T new_elem = fmutate(DowncastNoCheck(*s->itr)); + // do nothing when there is no mutation + if (new_elem.same_as(*s->itr)) { + continue; + } + // loop invariant breaks when the first real mutation happens + // we copy the elements into a new unique array + ObjectPtr copy = ArrayNode::CopyFrom(s->p->capacity_, s->p); + s->itr = copy->MutableBegin() + (s->i++); + *s->itr++ = std::move(new_elem); + data_ = std::move(copy); + // make sure `data_` is unique and break + break; + } + } + // when execution comes to this line, it is guaranteed that either + // 1) i == size + // or 2) data_.unique() is true + for (; s->i < s->size; ++s->i, ++s->itr) { + *s->itr = std::move(fmutate(std::move(DowncastNoCheck(std::move(*s->itr))))); + } + } + + /*! + * \brief reset the array to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + void Assign(IterType first, IterType last) { + int64_t cap = std::distance(first, last); + ICHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size"; + ArrayNode* p = GetArrayNode(); + if (p != nullptr && data_.unique() && p->capacity_ >= cap) { + // do not have to make new space + p->clear(); + } else { + // create new space + data_ = ArrayNode::Empty(cap); + p = GetArrayNode(); + } + // To ensure exception safety, size is only incremented after the initialization succeeds + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { + new (itr) ObjectRef(*first); + } + } + + /*! + * \brief Copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + ArrayNode* CopyOnWrite() { + if (data_ == nullptr) { + return SwitchContainer(ArrayNode::kInitSize); + } + if (!data_.unique()) { + return SwitchContainer(capacity()); + } + return static_cast(data_.get()); + } + + /*! \brief specify container node */ + using ContainerType = ArrayNode; + + private: + /*! + * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. + * \param reserve_extra Number of extra slots needed + * \return ArrayNode pointer to the unique copy + */ + ArrayNode* CopyOnWrite(int64_t reserve_extra) { + ArrayNode* p = GetArrayNode(); + if (p == nullptr) { + // necessary to get around the constexpr address issue before c++17 + const int64_t kInitSize = ArrayNode::kInitSize; + return SwitchContainer(std::max(kInitSize, reserve_extra)); + } + if (p->capacity_ >= p->size_ + reserve_extra) { + return CopyOnWrite(); + } + int64_t cap = p->capacity_ * ArrayNode::kIncFactor; + cap = std::max(cap, p->size_ + reserve_extra); + return SwitchContainer(cap); + } + + /*! + * \brief Move or copy the ArrayNode to new address with the given capacity + * \param capacity The capacity requirement of the new address + */ + ArrayNode* SwitchContainer(int64_t capacity) { + if (data_ == nullptr) { + data_ = ArrayNode::Empty(capacity); + } else if (data_.unique()) { + data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); + } else { + data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); + } + return static_cast(data_.get()); + } +}; + +/*! + * \brief Concat two Arrays. + * \param lhs first Array to be concatenated. + * \param rhs second Array to be concatenated. + * \return The concatenated Array. Original Arrays are kept unchanged. + */ +template ::value>::type> +inline Array Concat(Array lhs, const Array& rhs) { + for (const auto& x : rhs) { + lhs.push_back(x); + } + return std::move(lhs); +} + +// Specialize make_object to make sure it is correct. +template <> +inline ObjectPtr make_object() { + return ArrayNode::Empty(); +} + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Array; +using runtime::ArrayNode; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_ARRAY_H_ diff --git a/include/tvm/runtime/container/base.h b/include/tvm/runtime/container/base.h new file mode 100644 index 000000000000..4112c213d6f0 --- /dev/null +++ b/include/tvm/runtime/container/base.h @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/base.h + * \brief Base utilities for common POD(plain old data) container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_BASE_H_ +#define TVM_RUNTIME_CONTAINER_BASE_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief String-aware ObjectRef equal functor */ +struct ObjectHash { + /*! + * \brief Calculate the hash code of an ObjectRef + * \param a The given ObjectRef + * \return Hash code of a, string hash for strings and pointer address otherwise. + */ + size_t operator()(const ObjectRef& a) const; +}; + +/*! \brief String-aware ObjectRef hash functor */ +struct ObjectEqual { + /*! + * \brief Check if the two ObjectRef are equal + * \param a One ObjectRef + * \param b The other ObjectRef + * \return String equality if both are strings, pointer address equality otherwise. + */ + bool operator()(const ObjectRef& a, const ObjectRef& b) const; +}; + +/*! + * \brief Base template for classes with array like memory layout. + * + * It provides general methods to access the memory. The memory + * layout is ArrayType + [ElemType]. The alignment of ArrayType + * and ElemType is handled by the memory allocator. + * + * \tparam ArrayType The array header type, contains object specific metadata. + * \tparam ElemType The type of objects stored in the array right after + * ArrayType. + * + * \code + * // Example usage of the template to define a simple array wrapper + * class ArrayObj : public InplaceArrayBase { + * public: + * // Wrap EmplaceInit to initialize the elements + * template + * void Init(Iterator begin, Iterator end) { + * size_t num_elems = std::distance(begin, end); + * auto it = begin; + * this->size = 0; + * for (size_t i = 0; i < num_elems; ++i) { + * InplaceArrayBase::EmplaceInit(i, *it++); + * this->size++; + * } + * } + * } + * + * void test_function() { + * vector fields; + * auto ptr = make_inplace_array_object(fields.size()); + * ptr->Init(fields.begin(), fields.end()); + * + * // Access the 0th element in the array. + * assert(ptr->operator[](0) == fields[0]); + * } + * + * \endcode + */ +template +class InplaceArrayBase { + public: + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Const reference to ElemType at the index. + */ + const ElemType& operator[](size_t idx) const { + size_t size = Self()->GetSize(); + ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Reference to ElemType at the index. + */ + ElemType& operator[](size_t idx) { + size_t size = Self()->GetSize(); + ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Destroy the Inplace Array Base object + */ + ~InplaceArrayBase() { + if (!(std::is_standard_layout::value && std::is_trivial::value)) { + size_t size = Self()->GetSize(); + for (size_t i = 0; i < size; ++i) { + ElemType* fp = reinterpret_cast(AddressOf(i)); + fp->ElemType::~ElemType(); + } + } + } + + protected: + /*! + * \brief Construct a value in place with the arguments. + * + * \tparam Args Type parameters of the arguments. + * \param idx Index of the element. + * \param args Arguments to construct the new value. + * + * \note Please make sure ArrayType::GetSize returns 0 before first call of + * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. + */ + template + void EmplaceInit(size_t idx, Args&&... args) { + void* field_ptr = AddressOf(idx); + new (field_ptr) ElemType(std::forward(args)...); + } + + /*! + * \brief Return the self object for the array. + * + * \return Pointer to ArrayType. + */ + inline ArrayType* Self() const { + return static_cast(const_cast(this)); + } + + /*! + * \brief Return the raw pointer to the element at idx. + * + * \param idx The index of the element. + * \return Raw pointer to the element. + */ + void* AddressOf(size_t idx) const { + static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); + + size_t kDataStart = sizeof(ArrayType); + ArrayType* self = Self(); + char* data_start = reinterpret_cast(self) + kDataStart; + return data_start + idx * sizeof(ElemType); + } +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class IterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit IterAdapter(TIter iter) : iter_(iter) {} + IterAdapter& operator++() { + ++iter_; + return *this; + } + IterAdapter& operator--() { + --iter_; + return *this; + } + IterAdapter operator++(int) { + IterAdapter copy = *this; + ++iter_; + return copy; + } + IterAdapter operator--(int) { + IterAdapter copy = *this; + --iter_; + return copy; + } + + IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } + + IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + + bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(IterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class ReverseIterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} + ReverseIterAdapter& operator++() { + --iter_; + return *this; + } + ReverseIterAdapter& operator--() { + ++iter_; + return *this; + } + ReverseIterAdapter& operator++(int) { + ReverseIterAdapter copy = *this; + --iter_; + return copy; + } + ReverseIterAdapter& operator--(int) { + ReverseIterAdapter copy = *this; + ++iter_; + return copy; + } + ReverseIterAdapter operator+(difference_type offset) const { + return ReverseIterAdapter(iter_ - offset); + } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const ReverseIterAdapter& rhs) const { + return rhs.iter_ - iter_; + } + + bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Downcast; +using runtime::IterAdapter; +using runtime::make_object; +using runtime::Object; +using runtime::ObjectEqual; +using runtime::ObjectHash; +using runtime::ObjectPtr; +using runtime::ObjectPtrEqual; +using runtime::ObjectPtrHash; +using runtime::ObjectRef; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_BASE_H_ diff --git a/include/tvm/runtime/container/closure.h b/include/tvm/runtime/container/closure.h new file mode 100644 index 000000000000..a280d1ada7a9 --- /dev/null +++ b/include/tvm/runtime/container/closure.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/closure.h + * \brief Runtime Closure container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_CLOSURE_H_ +#define TVM_RUNTIME_CONTAINER_CLOSURE_H_ + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief An object representing a closure. This object is used by both the + * Relay VM and interpreter. + */ +class ClosureObj : public Object { + public: + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure; + static constexpr const char* _type_key = "runtime.Closure"; + TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object); +}; + +/*! \brief reference to closure. */ +class Closure : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_CLOSURE_H_ diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h new file mode 100644 index 000000000000..671e38b83581 --- /dev/null +++ b/include/tvm/runtime/container/map.h @@ -0,0 +1,1441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/map.h + * \brief Runtime Map container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_MAP_H_ +#define TVM_RUNTIME_CONTAINER_MAP_H_ + +#ifndef USE_FALLBACK_STL_MAP +#define USE_FALLBACK_STL_MAP 0 +#endif + +#include +#include +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +#if (USE_FALLBACK_STL_MAP != 0) + +/*! \brief Shared content of all specializations of hash map */ +class MapNode : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = ObjectRef; + /*! \brief Type of the values in the hash map */ + using mapped_type = ObjectRef; + /*! \brief Type of the actual underlying container */ + using ContainerType = std::unordered_map; + /*! \brief Iterator class */ + using iterator = ContainerType::iterator; + /*! \brief Iterator class */ + using const_iterator = ContainerType::const_iterator; + /*! \brief Type of value stored in the hash map */ + using KVType = ContainerType::value_type; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + + /*! + * \brief Number of elements in the SmallMapNode + * \return The result + */ + size_t size() const { return data_.size(); } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return data_.count(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return data_.at(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return data_.at(key); } + /*! \return begin iterator */ + iterator begin() { return data_.begin(); } + /*! \return const begin iterator */ + const_iterator begin() const { return data_.begin(); } + /*! \return end iterator */ + iterator end() { return data_.end(); } + /*! \return end iterator */ + const_iterator end() const { return data_.end(); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + const_iterator find(const key_type& key) const { return data_.find(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) { return data_.find(key); } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { data_.erase(position); } + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { data_.erase(key); } + /*! + * \brief Create an empty container + * \return The object created + */ + static ObjectPtr Empty() { return make_object(); } + + protected: + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static ObjectPtr CreateFromRange(IterType first, IterType last) { + ObjectPtr p = make_object(); + p->data_ = ContainerType(first, last); + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + MapNode* map_node = static_cast(map->get()); + map_node->data_[kv.first] = kv.second; + } + /*! + * \brief Create an empty container with elements copying from another MapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(MapNode* from) { + ObjectPtr p = make_object(); + p->data_ = ContainerType(from->data_.begin(), from->data_.end()); + return p; + } + /*! \brief The real container storing data */ + ContainerType data_; + template + friend class Map; +}; + +#else + +/*! \brief Shared content of all specializations of hash map */ +class MapNode : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = ObjectRef; + /*! \brief Type of the values in the hash map */ + using mapped_type = ObjectRef; + /*! \brief Type of value stored in the hash map */ + using KVType = std::pair; + /*! \brief Iterator class */ + class iterator; + + static_assert(std::is_standard_layout::value, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect"); + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap; + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); + + /*! + * \brief Number of elements in the SmallMapNode + * \return The result + */ + size_t size() const { return size_; } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key); + /*! \return begin iterator */ + iterator begin() const; + /*! \return end iterator */ + iterator end() const; + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const; + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position); + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { erase(find(key)); } + + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = int64_t; + using value_type = KVType; + using pointer = KVType*; + using reference = KVType&; + /*! \brief Default constructor */ + iterator() : index(0), self(nullptr) {} + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { + return index == other.index && self == other.self; + } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return !(*this == other); } + /*! \brief De-reference iterators */ + pointer operator->() const; + /*! \brief De-reference iterators */ + reference operator*() const { return *((*this).operator->()); } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++(); + /*! \brief Prefix self decrement, e.g. --iter */ + iterator& operator--(); + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + iterator copy = *this; + --(*this); + return copy; + } + + protected: + /*! \brief Construct by value */ + iterator(uint64_t index, const MapNode* self) : index(index), self(self) {} + /*! \brief The position on the array */ + uint64_t index; + /*! \brief The container it points to */ + const MapNode* self; + + friend class DenseMapNode; + friend class SmallMapNode; + }; + /*! + * \brief Create an empty container + * \return The object created + */ + static inline ObjectPtr Empty(); + + protected: + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static inline ObjectPtr CreateFromRange(IterType first, IterType last); + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr* map); + /*! + * \brief Create an empty container with elements copying from another SmallMapNode + * \param from The source container + * \return The object created + */ + static inline ObjectPtr CopyFrom(MapNode* from); + /*! \brief number of slots minus 1 */ + uint64_t slots_; + /*! \brief number of entries in the container */ + uint64_t size_; + // Reference class + template + friend class Map; +}; + +/*! \brief A specialization of small-sized hash map */ +class SmallMapNode : public MapNode, + public runtime::InplaceArrayBase { + private: + static constexpr uint64_t kInitSize = 2; + static constexpr uint64_t kMaxSize = 4; + + public: + using MapNode::iterator; + using MapNode::KVType; + + /*! \brief Defaults to the destructor of InplaceArrayBase */ + ~SmallMapNode() = default; + /*! + * \brief Count the number of times a key exists in the SmallMapNode + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return find(key).index < size_; } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { + iterator itr = find(key); + ICHECK(itr.index < size_) << "IndexError: key is not in Map"; + return itr->second; + } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { + iterator itr = find(key); + ICHECK(itr.index < size_) << "IndexError: key is not in Map"; + return itr->second; + } + /*! \return begin iterator */ + iterator begin() const { return iterator(0, this); } + /*! \return end iterator */ + iterator end() const { return iterator(size_, this); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + KVType* ptr = static_cast(AddressOf(0)); + for (uint64_t i = 0; i < size_; ++i, ++ptr) { + if (ObjectEqual()(ptr->first, key)) { + return iterator(i, this); + } + } + return iterator(size_, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { Erase(position.index); } + + private: + /*! + * \brief Remove a position in SmallMapNode + * \param index The position to be removed + */ + void Erase(const uint64_t index) { + if (index >= size_) { + return; + } + KVType* begin = static_cast(AddressOf(0)); + KVType* last = begin + (size_ - 1); + if (index + 1 == size_) { + last->first.ObjectRef::~ObjectRef(); + last->second.ObjectRef::~ObjectRef(); + } else { + *(begin + index) = std::move(*last); + } + size_ -= 1; + } + /*! + * \brief Create an empty container + * \param n Number of empty slots + * \return The object created + */ + static ObjectPtr Empty(uint64_t n = kInitSize) { + using ::tvm::runtime::make_inplace_array_object; + ObjectPtr p = make_inplace_array_object(n); + p->size_ = 0; + p->slots_ = n; + return p; + } + /*! + * \brief Create an empty container initialized with a given range + * \param n Number of empty slots + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + * \return The object created + */ + template + static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { + ObjectPtr p = Empty(n); + KVType* ptr = static_cast(p->AddressOf(0)); + for (; first != last; ++first, ++p->size_) { + new (ptr++) KVType(*first); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another SmallMapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(SmallMapNode* from) { + KVType* first = static_cast(from->AddressOf(0)); + KVType* last = first + from->size_; + return CreateFromRange(from->size_, first, last); + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + SmallMapNode* map_node = static_cast(map->get()); + iterator itr = map_node->find(kv.first); + if (itr.index < map_node->size_) { + itr->second = kv.second; + return; + } + if (map_node->size_ < map_node->slots_) { + KVType* ptr = static_cast(map_node->AddressOf(map_node->size_)); + new (ptr) KVType(kv); + ++map_node->size_; + return; + } + uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize)); + next_size = std::min(next_size, uint64_t(kMaxSize)); + ICHECK_GT(next_size, map_node->slots_); + ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); + InsertMaybeReHash(kv, &new_map); + *map = std::move(new_map); + } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return static_cast(AddressOf(index)); } + /*! \brief A size function used by InplaceArrayBase */ + uint64_t GetSize() const { return size_; } + + protected: + friend class MapNode; + friend class DenseMapNode; + friend class runtime::InplaceArrayBase; +}; + +/*! \brief A specialization of hash map that implements the idea of array-based hash map. + * Another reference implementation can be found [1]. + * + * A. Overview + * + * DenseMapNode did several improvements over traditional separate chaining hash, + * in terms of cache locality, memory footprints and data organization. + * + * A1. Implicit linked list. For better cache locality, instead of using linked list + * explicitly for each bucket, we store list data into a single array that spans contiguously + * in memory, and then carefully design access patterns to make sure most of them fall into + * a single cache line. + * + * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and + * traversal. This can be divided in 3 parts. + * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, + * which means the slot is empty but not allowed to be written. + * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is + * head of a linked list. + * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit + * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when + * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are + * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to + * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, + * then x must be one of the 126 pre-defined values. + * + * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. + * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. + * 16 key-value pairs. + * + * B. Implementation details + * + * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid + * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, + * we use the Fibonacci Hashing [2] trick. + * + * B2. Traverse a linked list in the array. + * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i + * indicates that it is list head, then we found the head; otherwise the list is empty. No probing + * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we + * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of + * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). + * + * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this + * element is in the linked list, and if not, we put it at the end by probing the next empty + * position in one of the 126 candidate positions. If the linked list does not even exist, but the + * slot for list head has been occupied by another linked list, we should find this intruder another + * place. + * + * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing + * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the + * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list + * head. + * + * [1] https://github.com/skarupke/flat_hash_map + * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ + * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + */ +class DenseMapNode : public MapNode { + private: + /*! \brief The number of elements in a memory block */ + static constexpr int kBlockCap = 16; + /*! \brief Maximum load factor of the hash map */ + static constexpr double kMaxLoadFactor = 0.99; + /*! \brief Binary representation of the metadata of an empty slot */ + static constexpr uint8_t kEmptySlot = uint8_t(0b11111111); + /*! \brief Binary representation of the metadata of a protected slot */ + static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110); + /*! \brief Number of probing choices available */ + static constexpr int kNumJumpDists = 126; + /*! \brief Head of the implicit linked list */ + struct ListNode; + /*! \brief POD type of a block of memory */ + struct Block { + uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)]; + }; + static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect"); + static_assert(std::is_standard_layout::value, "Block is not standard layout"); + + public: + using MapNode::iterator; + + /*! + * \brief Destroy the DenseMapNode + */ + ~DenseMapNode() { this->Reset(); } + /*! \return The number of elements of the key */ + size_t count(const key_type& key) const { return !Search(key).IsNone(); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return At(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return At(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + ListNode node = Search(key); + return node.IsNone() ? end() : iterator(node.index, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { + uint64_t index = position.index; + if (position.self != nullptr && index <= this->slots_) { + Erase(ListNode(index, this)); + } + } + /*! \return begin iterator */ + iterator begin() const { + if (slots_ == 0) { + return iterator(0, this); + } + for (uint64_t index = 0; index <= slots_; ++index) { + if (!ListNode(index, this).IsEmpty()) { + return iterator(index, this); + } + } + return iterator(slots_ + 1, this); + } + /*! \return end iterator */ + iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); } + + private: + /*! + * \brief Search for the given key + * \param key The key + * \return ListNode that associated with the key + */ + ListNode Search(const key_type& key) const { + if (this->size_ == 0) { + return ListNode(); + } + for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { + if (ObjectEqual()(key, iter.Key())) { + return iter; + } + } + return ListNode(); + } + /*! + * \brief Search for the given key, throw exception if not exists + * \param key The key + * \return ListNode that associated with the key + */ + mapped_type& At(const key_type& key) const { + ListNode iter = Search(key); + ICHECK(!iter.IsNone()) << "IndexError: key is not in Map"; + return iter.Val(); + } + /*! + * \brief Try to insert a key, or do nothing if already exists + * \param key The indexing key + * \param result The linked-list entry found or just constructed + * \return A boolean, indicating if actual insertion happens + */ + bool TryInsert(const key_type& key, ListNode* result) { + if (slots_ == 0) { + return false; + } + // required that `iter` to be the head of a linked list through which we can iterator + ListNode iter = IndexFromHash(ObjectHash()(key)); + // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list + // Case 1: empty + if (iter.IsEmpty()) { + iter.NewHead(KVType(key, ObjectRef(nullptr))); + this->size_ += 1; + *result = iter; + return true; + } + // Case 2: body of an irrelevant list + if (!iter.IsHead()) { + // we move the elements around and construct the single-element linked list + return IsFull() ? false : TrySpareListHead(iter, key, result); + } + // Case 3: head of the relevant list + // we iterate through the linked list until the end + // make sure `iter` is the previous element of `next` + ListNode next = iter; + do { + // find equal item, do not insert + if (ObjectEqual()(key, next.Key())) { + *result = next; + return true; + } + // make sure `iter` is the previous element of `next` + iter = next; + } while (next.MoveToNext(this)); + // `iter` is the tail of the linked list + // always check capacity before insertion + if (IsFull()) { + return false; + } + // find the next empty slot + uint8_t jump; + if (!iter.GetNextEmpty(this, &jump, result)) { + return false; + } + result->NewTail(KVType(key, ObjectRef(nullptr))); + // link `iter` to `empty`, and move forward + iter.SetJump(jump); + this->size_ += 1; + return true; + } + /*! + * \brief Spare an entry to be the head of a linked list. + * As described in B3, during insertion, it is possible that the entire linked list does not + * exist, but the slot of its head has been occupied by other linked lists. In this case, we need + * to spare the slot by moving away the elements to another valid empty one to make insertion + * possible. + * \param target The given entry to be spared + * \param key The indexing key + * \param result The linked-list entry constructed as the head + * \return A boolean, if actual insertion happens + */ + bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { + // `target` is not the head of the linked list + // move the original item of `target` (if any) + // and construct new item on the position `target` + // To make `target` empty, we + // 1) find `w` the previous element of `target` in the linked list + // 2) copy the linked list starting from `r = target` + // 3) paste them after `w` + // read from the linked list after `r` + ListNode r = target; + // write to the tail of `w` + ListNode w = target.FindPrev(this); + // after `target` is moved, we disallow writing to the slot + bool is_first = true; + uint8_t r_meta, jump; + ListNode empty; + do { + // `jump` describes how `w` is jumped to `empty` + // rehash if there is no empty space after `w` + if (!w.GetNextEmpty(this, &jump, &empty)) { + return false; + } + // move `r` to `empty` + empty.NewTail(std::move(r.Data())); + // clear the metadata of `r` + r_meta = r.Meta(); + if (is_first) { + is_first = false; + r.SetProtected(); + } else { + r.SetEmpty(); + } + // link `w` to `empty`, and move forward + w.SetJump(jump); + w = empty; + // move `r` forward as well + } while (r.MoveToNext(this, r_meta)); + // finally we have done moving the linked list + // fill data_ into `target` + target.NewHead(KVType(key, ObjectRef(nullptr))); + this->size_ += 1; + *result = target; + return true; + } + /*! + * \brief Remove a ListNode + * \param iter The node to be removed + */ + void Erase(const ListNode& iter) { + this->size_ -= 1; + if (!iter.HasNext()) { + // `iter` is the last + if (!iter.IsHead()) { + // cut the link if there is any + iter.FindPrev(this).SetJump(0); + } + iter.Data().KVType::~KVType(); + iter.SetEmpty(); + } else { + ListNode last = iter, prev = iter; + for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { + } + iter.Data() = std::move(last.Data()); + last.SetEmpty(); + prev.SetJump(0); + } + } + /*! \brief Clear the container to empty, release all entries and memory acquired */ + void Reset() { + uint64_t n_blocks = CalcNumBlocks(this->slots_); + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = data_[bi].bytes; + KVType* data_ptr = reinterpret_cast(data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + data_ptr->KVType::~KVType(); + } + } + } + ReleaseMemory(); + } + /*! \brief Release the memory acquired by the container without deleting its entries stored inside + */ + void ReleaseMemory() { + delete[] data_; + data_ = nullptr; + slots_ = 0; + size_ = 0; + fib_shift_ = 63; + } + /*! + * \brief Create an empty container + * \param fib_shift The fib shift provided + * \param n_slots Number of slots required, should be power-of-two + * \return The object created + */ + static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { + ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize)); + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(n_slots - 1); + Block* block = p->data_ = new Block[n_blocks]; + p->slots_ = n_slots - 1; + p->size_ = 0; + p->fib_shift_ = fib_shift; + for (uint64_t i = 0; i < n_blocks; ++i, ++block) { + std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot)); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another DenseMapNode + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(DenseMapNode* from) { + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(from->slots_); + p->data_ = new Block[n_blocks]; + p->slots_ = from->slots_; + p->size_ = from->size_; + p->fib_shift_ = from->fib_shift_; + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr_from = from->data_[bi].bytes; + KVType* data_ptr_from = reinterpret_cast(from->data_[bi].bytes + kBlockCap); + uint8_t* meta_ptr_to = p->data_[bi].bytes; + KVType* data_ptr_to = reinterpret_cast(p->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; + ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { + uint8_t& meta = *meta_ptr_to = *meta_ptr_from; + ICHECK(meta != kProtectedSlot); + if (meta != uint8_t(kEmptySlot)) { + new (data_ptr_to) KVType(*data_ptr_from); + } + } + } + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + DenseMapNode* map_node = static_cast(map->get()); + ListNode iter; + // Try to insert. If succeed, we simply return + if (map_node->TryInsert(kv.first, &iter)) { + iter.Val() = kv.second; + return; + } + ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize)); + // Otherwise, start rehash + ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2); + // Insert the given `kv` into the new hash map + InsertMaybeReHash(kv, &p); + uint64_t n_blocks = CalcNumBlocks(map_node->slots_); + // Then Insert data from the original block. + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = map_node->data_[bi].bytes; + KVType* data_ptr = reinterpret_cast(map_node->data_[bi].bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) { + meta = uint8_t(kEmptySlot); + KVType kv = std::move(*data_ptr); + InsertMaybeReHash(kv, &p); + } + } + } + map_node->ReleaseMemory(); + *map = p; + } + /*! + * \brief Check whether the hash table is full + * \return A boolean indicating whether hash table is full + */ + bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { + for (++index; index <= slots_; ++index) { + if (!ListNode(index, this).IsEmpty()) { + return index; + } + } + return slots_ + 1; + } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { + while (index != 0) { + index -= 1; + if (!ListNode(index, this).IsEmpty()) { + return index; + } + } + return slots_ + 1; + } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } + /*! \brief Construct from hash code */ + ListNode IndexFromHash(uint64_t hash_value) const { + return ListNode(FibHash(hash_value, fib_shift_), this); + } + /*! \brief Construct from hash code if the position is head of list */ + ListNode GetListHead(uint64_t hash_value) const { + ListNode node = IndexFromHash(hash_value); + return node.IsHead() ? node : ListNode(); + } + /*! \brief Construct the number of blocks in the hash table */ + static uint64_t CalcNumBlocks(uint64_t n_slots_m1) { + uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0; + return (n_slots + kBlockCap - 1) / kBlockCap; + } + /*! + * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. + * \param cap The lower-bound of the required capacity + * \param fib_shift The result shift for Fibonacci Hashing + * \param n_slots The result number of slots + */ + static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { + uint32_t shift = 64; + uint64_t slots = 1; + for (uint64_t c = cap; c; c >>= 1) { + shift -= 1; + slots <<= 1; + } + ICHECK_GT(slots, cap); + if (slots < cap * 2) { + *fib_shift = shift - 1; + *n_slots = slots << 1; + } else { + *fib_shift = shift; + *n_slots = slots; + } + } + /*! + * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. + * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. + * \param hash_value The raw hash value + * \param fib_shift The shift in Fibonacci Hashing + * \return An index calculated using Fibonacci Hashing + */ + static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { + constexpr uint64_t coeff = 11400714819323198485ull; + return (coeff * hash_value) >> fib_shift; + } + /*! \brief The implicit in-place linked list used to index a chain */ + struct ListNode { + /*! \brief Construct None */ + ListNode() : index(0), block(nullptr) {} + /*! \brief Construct from position */ + ListNode(uint64_t index, const DenseMapNode* self) + : index(index), block(self->data_ + (index / kBlockCap)) {} + /*! \brief Metadata on the entry */ + uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } + /*! \brief Data on the entry */ + KVType& Data() const { + return *(reinterpret_cast(block->bytes + kBlockCap + + (index % kBlockCap) * sizeof(KVType))); + } + /*! \brief Key on the entry */ + key_type& Key() const { return Data().first; } + /*! \brief Value on the entry */ + mapped_type& Val() const { return Data().second; } + /*! \brief If the entry is head of linked list */ + bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } + /*! \brief If the entry is none */ + bool IsNone() const { return block == nullptr; } + /*! \brief If the entry is empty slot */ + bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); } + /*! \brief If the entry is protected slot */ + bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); } + /*! \brief Set the entry to be empty */ + void SetEmpty() const { Meta() = uint8_t(kEmptySlot); } + /*! \brief Set the entry to be protected */ + void SetProtected() const { Meta() = uint8_t(kProtectedSlot); } + /*! \brief Set the entry's jump to its next entry */ + void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } + /*! \brief Construct a head of linked list in-place */ + void NewHead(KVType v) const { + Meta() = 0b00000000; + new (&Data()) KVType(std::move(v)); + } + /*! \brief Construct a tail of linked list in-place */ + void NewTail(KVType v) const { + Meta() = 0b10000000; + new (&Data()) KVType(std::move(v)); + } + /*! \brief If the entry has next entry on the linked list */ + bool HasNext() const { return kNextProbeLocation[Meta() & 0b01111111] != 0; } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapNode* self, uint8_t meta) { + uint64_t offset = kNextProbeLocation[meta & 0b01111111]; + if (offset == 0) { + index = 0; + block = nullptr; + return false; + } + index = (index + offset) & (self->slots_); + block = self->data_ + (index / kBlockCap); + return true; + } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); } + /*! \brief Get the previous entry on the linked list */ + ListNode FindPrev(const DenseMapNode* self) const { + // start from the head of the linked list, which must exist + ListNode next = self->IndexFromHash(ObjectHash()(Key())); + // `prev` is always the previous item of `next` + ListNode prev = next; + for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { + } + return prev; + } + /*! \brief Get the next empty jump */ + bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const { + for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { + ListNode candidate((index + kNextProbeLocation[idx]) & (self->slots_), self); + if (candidate.IsEmpty()) { + *jump = idx; + *result = candidate; + return true; + } + } + return false; + } + /*! \brief Index on the real array */ + uint64_t index; + /*! \brief Pointer to the actual block */ + Block* block; + }; + + protected: + /*! \brief fib shift in Fibonacci Hashing */ + uint32_t fib_shift_; + /*! \brief array of data blocks */ + Block* data_; + /* clang-format off */ + /*! \brief Candidates of probing distance */ + TVM_DLL static constexpr uint64_t kNextProbeLocation[kNumJumpDists] { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + // Quadratic probing with triangle numbers. See also: + // 1) https://en.wikipedia.org/wiki/Quadratic_probing + // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + // 3) https://github.com/skarupke/flat_hash_map + 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, + 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, + 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, + 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, + 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, + 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, + 2211, 2278, 2346, 2415, 2485, 2556, 2628, + // larger triangle numbers + 8515, 19110, 42778, 96141, 216153, + 486591, 1092981, 2458653, 5532801, 12442566, + 27993903, 62983476, 141717030, 318844378, 717352503, + 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, + 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, + 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, + 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, + 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626, + 1029107982097042876, 2315492959180353330, 5209859154120846435, + }; + /* clang-format on */ + friend class MapNode; +}; + +#define TVM_DISPATCH_MAP(base, var, body) \ + { \ + using TSmall = SmallMapNode*; \ + using TDense = DenseMapNode*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapNode::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +#define TVM_DISPATCH_MAP_CONST(base, var, body) \ + { \ + using TSmall = const SmallMapNode*; \ + using TDense = const DenseMapNode*; \ + uint64_t slots = base->slots_; \ + if (slots <= SmallMapNode::kMaxSize) { \ + TSmall var = static_cast(base); \ + body; \ + } else { \ + TDense var = static_cast(base); \ + body; \ + } \ + } + +inline MapNode::iterator::pointer MapNode::iterator::operator->() const { + TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); +} + +inline MapNode::iterator& MapNode::iterator::operator++() { + TVM_DISPATCH_MAP_CONST(self, p, { + index = p->IncItr(index); + return *this; + }); +} + +inline MapNode::iterator& MapNode::iterator::operator--() { + TVM_DISPATCH_MAP_CONST(self, p, { + index = p->DecItr(index); + return *this; + }); +} + +inline size_t MapNode::count(const key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); +} + +inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); +} + +inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) { + TVM_DISPATCH_MAP(this, p, { return p->at(key); }); +} + +inline MapNode::iterator MapNode::begin() const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); +} + +inline MapNode::iterator MapNode::end() const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); }); +} + +inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const { + TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); +} + +inline void MapNode::erase(const MapNode::iterator& position) { + TVM_DISPATCH_MAP(this, p, { return p->erase(position); }); +} + +#undef TVM_DISPATCH_MAP +#undef TVM_DISPATCH_MAP_CONST + +inline ObjectPtr MapNode::Empty() { return SmallMapNode::Empty(); } + +inline ObjectPtr MapNode::CopyFrom(MapNode* from) { + if (from->slots_ <= SmallMapNode::kMaxSize) { + return SmallMapNode::CopyFrom(static_cast(from)); + } else { + return DenseMapNode::CopyFrom(static_cast(from)); + } +} + +template +inline ObjectPtr MapNode::CreateFromRange(IterType first, IterType last) { + int64_t _cap = std::distance(first, last); + if (_cap < 0) { + return SmallMapNode::Empty(); + } + uint64_t cap = static_cast(_cap); + if (cap < SmallMapNode::kMaxSize) { + return SmallMapNode::CreateFromRange(cap, first, last); + } + uint32_t fib_shift; + uint64_t n_slots; + DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots); + ObjectPtr obj = DenseMapNode::Empty(fib_shift, n_slots); + for (; first != last; ++first) { + KVType kv(*first); + DenseMapNode::InsertMaybeReHash(kv, &obj); + } + return obj; +} + +inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr* map) { + constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize; + MapNode* base = static_cast(map->get()); + if (base->slots_ < kSmallMapMaxSize) { + SmallMapNode::InsertMaybeReHash(kv, map); + } else if (base->slots_ == kSmallMapMaxSize) { + if (base->size_ < base->slots_) { + SmallMapNode::InsertMaybeReHash(kv, map); + } else { + ObjectPtr new_map = MapNode::CreateFromRange(base->begin(), base->end()); + DenseMapNode::InsertMaybeReHash(kv, &new_map); + *map = std::move(new_map); + } + } else { + DenseMapNode::InsertMaybeReHash(kv, map); + } +} + +template <> +inline ObjectPtr make_object<>() = delete; + +#endif + +/*! + * \brief Map container of NodeRef->NodeRef in DSL graph. + * Map implements copy on write semantics, which means map is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const acces, use Set to mutate the content. + * \tparam K The key NodeRef type. + * \tparam V The value NodeRef type. + */ +template ::value>::type, + typename = typename std::enable_if::value>::type> +class Map : public ObjectRef { + public: + using key_type = K; + using mapped_type = V; + class iterator; + /*! + * \brief default constructor + */ + Map() { data_ = MapNode::Empty(); } + /*! + * \brief move constructor + * \param other source + */ + Map(Map&& other) { data_ = std::move(other.data_); } + /*! + * \brief copy constructor + * \param other source + */ + Map(const Map& other) : ObjectRef(other.data_) {} + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map& operator=(Map&& other) { + data_ = std::move(other.data_); + return *this; + } + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map& operator=(const Map& other) { + data_ = other.data_; + return *this; + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Map(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Map(IterType begin, IterType end) { + data_ = MapNode::CreateFromRange(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Map(std::initializer_list> init) { + data_ = MapNode::CreateFromRange(init.begin(), init.end()); + } + /*! + * \brief constructor from unordered_map + * \param init The unordered_map + */ + template + Map(const std::unordered_map& init) { // NOLINT(*) + data_ = MapNode::CreateFromRange(init.begin(), init.end()); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + const V at(const K& key) const { return DowncastNoCheck(GetMapNode()->at(key)); } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + const V operator[](const K& key) const { return this->at(key); } + /*! \return The size of the array */ + size_t size() const { + MapNode* n = GetMapNode(); + return n == nullptr ? 0 : n->size(); + } + /*! \return The number of elements of the key */ + size_t count(const K& key) const { + MapNode* n = GetMapNode(); + return n == nullptr ? 0 : GetMapNode()->count(key); + } + /*! \return whether array is empty */ + bool empty() const { return size() == 0; } + /*! \brief Release reference to all the elements */ + void clear() { + MapNode* n = GetMapNode(); + if (n != nullptr) { + data_ = MapNode::Empty(); + } + } + /*! + * \brief set the Map. + * \param key The index key. + * \param value The value to be setted. + */ + void Set(const K& key, const V& value) { + CopyOnWrite(); + MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_); + } + /*! \return begin iterator */ + iterator begin() const { return iterator(GetMapNode()->begin()); } + /*! \return end iterator */ + iterator end() const { return iterator(GetMapNode()->end()); } + /*! \return find the key and returns the associated iterator */ + iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } + + void erase(const K& key) { CopyOnWrite()->erase(key); } + + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + MapNode* CopyOnWrite() { + if (data_.get() == nullptr) { + data_ = MapNode::Empty(); + } else if (!data_.unique()) { + data_ = MapNode::CopyFrom(GetMapNode()); + } + return GetMapNode(); + } + /*! \brief specify container node */ + using ContainerType = MapNode; + + /*! \brief Iterator of the hash map */ + class iterator { + public: + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = int64_t; + using value_type = const std::pair; + using pointer = value_type*; + using reference = value_type; + + iterator() : itr() {} + + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { return itr == other.itr; } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return itr != other.itr; } + /*! \brief De-reference iterators is not allowed */ + pointer operator->() const = delete; + /*! \brief De-reference iterators */ + reference operator*() const { + auto& kv = *itr; + return std::make_pair(DowncastNoCheck(kv.first), DowncastNoCheck(kv.second)); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++() { + ++itr; + return *this; + } + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } + + private: + iterator(const MapNode::iterator& itr) // NOLINT(*) + : itr(itr) {} + + template + friend class Map; + + MapNode::iterator itr; + }; + + private: + /*! \brief Return data_ as type of pointer of MapNode */ + MapNode* GetMapNode() const { return static_cast(data_.get()); } +}; + +/*! + * \brief Merge two Maps. + * \param lhs the first Map to merge. + * \param rhs the second Map to merge. + * @return The merged Array. Original Maps are kept unchanged. + */ +template ::value>::type, + typename = typename std::enable_if::value>::type> +inline Map Merge(Map lhs, const Map& rhs) { + for (const auto& p : rhs) { + lhs.Set(p.first, p.second); + } + return std::move(lhs); +} + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Map; +using runtime::MapNode; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_MAP_H_ diff --git a/include/tvm/runtime/container/optional.h b/include/tvm/runtime/container/optional.h new file mode 100644 index 000000000000..bea4228c48b8 --- /dev/null +++ b/include/tvm/runtime/container/optional.h @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/optional.h + * \brief Runtime Optional container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_OPTIONAL_H_ +#define TVM_RUNTIME_CONTAINER_OPTIONAL_H_ + +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! \brief Helper to represent nullptr for optional. */ +struct NullOptType {}; + +/*! + * \brief Optional container that to represent to a Nullable variant of T. + * \tparam T The original ObjectRef. + * + * \code + * + * Optional opt0 = nullptr; + * Optional opt1 = String("xyz"); + * ICHECK(opt0 == nullptr); + * ICHECK(opt1 == "xyz"); + * + * \endcode + */ +template +class Optional : public ObjectRef { + public: + using ContainerType = typename T::ContainerType; + static_assert(std::is_base_of::value, "Optional is only defined for ObjectRef."); + // default constructors. + Optional() = default; + Optional(const Optional&) = default; + Optional(Optional&&) = default; + Optional& operator=(const Optional&) = default; + Optional& operator=(Optional&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already matches the ContainerType. + * \param ptr + */ + explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief Nullopt handling */ + Optional(NullOptType) {} // NOLINT(*) + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit Optional(std::nullptr_t) {} + Optional& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + Optional& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // delete the int constructor + // since Optional(0) is ambiguious + // 0 can be implicitly casted to nullptr_t + explicit Optional(int val) = delete; + Optional& operator=(int val) = delete; + /*! + * \return A not-null container value in the optional. + * \note This function performs not-null checking. + */ + T value() const { + ICHECK(data_ != nullptr); + return T(data_); + } + /*! + * \return The contained value if the Optional is not null + * otherwise return the default_value. + */ + T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } + + /*! \return Whether the container is not nullptr.*/ + explicit operator bool() const { return *this != nullptr; } + // operator overloadings + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + auto operator==(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() == other.value()); + if (same_as(other)) return RetType(true); + if (*this != nullptr && other != nullptr) { + return value() == other.value(); + } else { + // one of them is nullptr. + return RetType(false); + } + } + auto operator!=(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() != other.value()); + if (same_as(other)) return RetType(false); + if (*this != nullptr && other != nullptr) { + return value() != other.value(); + } else { + // one of them is nullptr. + return RetType(true); + } + } + auto operator==(const T& other) const { + using RetType = decltype(value() == other); + if (same_as(other)) return RetType(true); + if (*this != nullptr) return value() == other; + return RetType(false); + } + auto operator!=(const T& other) const { return !(*this == other); } + template + auto operator==(const U& other) const { + using RetType = decltype(value() == other); + if (*this == nullptr) return RetType(false); + return value() == other; + } + template + auto operator!=(const U& other) const { + using RetType = decltype(value() != other); + if (*this == nullptr) return RetType(true); + return value() != other; + } + static constexpr bool _type_is_nullable = true; +}; + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::Optional; +constexpr runtime::NullOptType NullOpt{}; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_OPTIONAL_H_ diff --git a/include/tvm/runtime/container/string.h b/include/tvm/runtime/container/string.h new file mode 100644 index 000000000000..664d19818be1 --- /dev/null +++ b/include/tvm/runtime/container/string.h @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/string.h + * \brief Runtime String container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_STRING_H_ +#define TVM_RUNTIME_CONTAINER_STRING_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +// We use c++14 std::experimental::string_view for optimizing hash computation +// only right now, its usage is limited in this file. Any broader usage of +// std::experiment in our core codebase is discouraged and needs community +// discussion for each use case. Reference for feature test macros of +// string_view: +// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations +// https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros +#if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411 +#define TVM_USE_CXX14_STRING_VIEW_HASH 1 +#else +#define TVM_USE_CXX14_STRING_VIEW_HASH 0 +#endif + +// Tested with clang version 9.0.1 and c++17. It will detect string_view support +// correctly. +#if defined(__cpp_lib_string_view) && __cpp_lib_string_view >= 201606 +#define TVM_USE_CXX17_STRING_VIEW_HASH 1 +#else +#define TVM_USE_CXX17_STRING_VIEW_HASH 0 +#endif + +#if TVM_USE_CXX17_STRING_VIEW_HASH +#include +#elif TVM_USE_CXX14_STRING_VIEW_HASH +#include +#endif + +#include +#include +#include + +namespace llvm { +// String to llvm object compatibility. +class StringRef; +} // namespace llvm + +namespace tvm { +namespace runtime { + +// Forward declare TVMArgValue +class TVMArgValue; + +/*! \brief An object representing string. It's POD type. */ +class StringObj : public Object { + public: + /*! \brief The pointer to string data. */ + const char* data; + + /*! \brief The length of the string object. */ + uint64_t size; + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString; + static constexpr const char* _type_key = "runtime.String"; + TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); + + private: + /*! \brief String object which is moved from std::string container. */ + class FromStd; + + friend class String; +}; + +/*! + * \brief Reference to string objects. + * + * \code + * + * // Example to create runtime String reference object from std::string + * std::string s = "hello world"; + * + * // You can create the reference from existing std::string + * String ref{std::move(s)}; + * + * // You can rebind the reference to another string. + * ref = std::string{"hello world2"}; + * + * // You can use the reference as hash map key + * std::unordered_map m; + * m[ref] = 1; + * + * // You can compare the reference object with other string objects + * assert(ref == "hello world", true); + * + * // You can convert the reference to std::string again + * string s2 = (string)ref; + * + * \endcode + */ +class String : public ObjectRef { + public: + /*! + * \brief Construct an empty string. + */ + String() : String(std::string()) {} + /*! + * \brief Construct a new String object + * + * \param other The moved/copied std::string object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + String(std::string other); // NOLINT(*) + + /*! + * \brief Construct a new String object + * + * \param other a char array. + */ + String(const char* other) // NOLINT(*) + : String(std::string(other)) {} + + /*! + * \brief Change the value the reference object points to. + * + * \param other The value for the new String + * + */ + inline String& operator=(std::string other); + + /*! + * \brief Change the value the reference object points to. + * + * \param other The value for the new String + */ + inline String& operator=(const char* other); + + /*! + * \brief Compares this String object to other + * + * \param other The String to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const String& other) const { + return memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this String object to other + * + * \param other The string to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const std::string& other) const { + return memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this to other + * + * \param other The character array to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const char* other) const { + return memncmp(data(), other, size(), std::strlen(other)); + } + + /*! + * \brief Returns a pointer to the char array in the string. + * + * \return const char* + */ + const char* c_str() const { return get()->data; } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const { + const auto* ptr = get(); + return ptr->size; + } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t length() const { return size(); } + + /*! + * \brief Retun if the string is empty + * + * \return true if empty, false otherwise. + */ + bool empty() const { return size() == 0; } + + /*! + * \brief Read an element. + * \param pos The position at which to read the character. + * + * \return The char at position + */ + char at(size_t pos) const { + if (pos < size()) { + return data()[pos]; + } else { + throw std::out_of_range("tvm::String index out of bounds"); + } + } + + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + const char* data() const { return get()->data; } + + /*! + * \brief Convert String to an std::string object + * + * \return std::string + */ + operator std::string() const { return std::string{get()->data, size()}; } + + // LLVM compatibility function, implemented in src/target/llvm/llvm_common.h + /*! + * \brief Convert String to an llvm::StringRef object + * + * \return llvm::StringRef + */ + inline operator llvm::StringRef() const; + + /*! + * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String + * \param val The value to be checked + * \return A boolean indicating if val can be converted to String + */ + inline static bool CanConvertFrom(const TVMArgValue& val); + + /*! + * \brief Hash the binary bytes + * \param data The data pointer + * \param size The size of the bytes. + * \return the hash value. + */ + static size_t HashBytes(const char* data, size_t size) { + // This function falls back to string copy with c++11 compiler and is + // recommended to be compiled with c++14 +#if TVM_USE_CXX17_STRING_VIEW_HASH + return std::hash()(std::string_view(data, size)); +#elif TVM_USE_CXX14_STRING_VIEW_HASH + return std::hash()(std::experimental::string_view(data, size)); +#else + return std::hash()(std::string(data, size)); +#endif + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + + private: + /*! + * \brief Compare two char sequence + * + * \param lhs Pointers to the char array to compare + * \param rhs Pointers to the char array to compare + * \param lhs_count Length of the char array to compare + * \param rhs_count Length of the char array to compare + * \return int zero if both char sequences compare equal. negative if this + * appear before other, positive otherwise. + */ + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); + + /*! + * \brief Concatenate two char sequences + * + * \param lhs Pointers to the lhs char array + * \param lhs_size The size of the lhs char array + * \param rhs Pointers to the rhs char array + * \param rhs_size The size of the rhs char array + * + * \return The concatenated char sequence + */ + static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) { + std::string ret(lhs, lhs_size); + ret.append(rhs, rhs_size); + return String(ret); + } + + // Overload + operator + friend String operator+(const String& lhs, const String& rhs); + friend String operator+(const String& lhs, const std::string& rhs); + friend String operator+(const std::string& lhs, const String& rhs); + friend String operator+(const String& lhs, const char* rhs); + friend String operator+(const char* lhs, const String& rhs); + + friend struct tvm::runtime::ObjectEqual; +}; + +/*! \brief An object representing string moved from std::string. */ +class StringObj::FromStd : public StringObj { + public: + /*! + * \brief Construct a new FromStd object + * + * \param other The moved/copied std::string object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + explicit FromStd(std::string other) : data_container{other} {} + + private: + /*! \brief Container that holds the memory. */ + std::string data_container; + + friend class String; +}; + +inline String::String(std::string other) { + auto ptr = make_object(std::move(other)); + ptr->size = ptr->data_container.size(); + ptr->data = ptr->data_container.data(); + data_ = std::move(ptr); +} + +inline String& String::operator=(std::string other) { + String replace{std::move(other)}; + data_.swap(replace.data_); + return *this; +} + +inline String& String::operator=(const char* other) { return operator=(std::string(other)); } + +inline String operator+(const String& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const std::string& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const std::string& lhs, const String& rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const char* lhs, const String& rhs) { + size_t lhs_size = std::strlen(lhs); + size_t rhs_size = rhs.size(); + return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String& lhs, const char* rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = std::strlen(rhs); + return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); +} + +// Overload < operator +inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; } + +// Overload > operator +inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; } + +// Overload <= operator +inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; } + +// Overload >= operator +inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; } + +inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; } + +// Overload == operator +inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; } + +// Overload != operator +inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; } + +inline std::ostream& operator<<(std::ostream& out, const String& input) { + out.write(input.data(), input.size()); + return out; +} + +inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { + if (lhs == rhs && lhs_count == rhs_count) return 0; + + for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { + if (lhs[i] < rhs[i]) return -1; + if (lhs[i] > rhs[i]) return 1; + } + if (lhs_count < rhs_count) { + return -1; + } else if (lhs_count > rhs_count) { + return 1; + } else { + return 0; + } +} + +inline size_t ObjectHash::operator()(const ObjectRef& a) const { + if (const auto* str = a.as()) { + return String::HashBytes(str->data, str->size); + } + return ObjectPtrHash()(a); +} + +inline bool ObjectEqual::operator()(const ObjectRef& a, const ObjectRef& b) const { + if (a.same_as(b)) { + return true; + } + if (const auto* str_a = a.as()) { + if (const auto* str_b = b.as()) { + return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; + } + } + return false; +} +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::String; +using runtime::StringObj; +} // namespace tvm + +namespace std { + +template <> +struct hash<::tvm::runtime::String> { + std::size_t operator()(const ::tvm::runtime::String& str) const { + return ::tvm::runtime::String::HashBytes(str.data(), str.size()); + } +}; +} // namespace std + +#endif // TVM_RUNTIME_CONTAINER_STRING_H_ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index ada9b74503bc..bfc681e24418 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -25,7 +25,9 @@ #define TVM_RUNTIME_NDARRAY_H_ #include -#include +#include +#include +#include #include #include #include diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 58bd2859c10a..3e8f23b755f9 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -25,7 +25,7 @@ #define TVM_RUNTIME_PACKED_FUNC_H_ #include -#include +#include #include #include #include diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index e0fabfc5d8aa..2cdd180730ec 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -24,7 +24,8 @@ #ifndef TVM_RUNTIME_VM_EXECUTABLE_H_ #define TVM_RUNTIME_VM_EXECUTABLE_H_ -#include +#include +#include #include #include #include diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 15de1df98a78..58c6ee037fb5 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_VM_VM_H_ #define TVM_RUNTIME_VM_VM_H_ -#include +#include #include #include #include diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 401ba102c2f4..85677a726574 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -25,7 +25,6 @@ #define TVM_TE_TENSOR_H_ #include -#include #include #include diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 83f228da9475..a01d69b372d2 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -25,7 +25,8 @@ #define TVM_TIR_BUFFER_H_ #include -#include +#include +#include #include #include diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index e1d097474dd9..40d66a2d8357 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -29,7 +29,9 @@ #include #include #include -#include +#include +#include +#include #include #include #include diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 963458ccee4a..6b5d6c48ddd0 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -29,7 +29,7 @@ #define TVM_TIR_OP_ATTR_TYPES_H_ #include -#include +#include #include namespace tvm { diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 8273f9912a57..a6681f0b9941 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -27,7 +27,6 @@ #define TVM_TIR_STMT_FUNCTOR_H_ #include -#include #include #include #include diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7c304727080e..ff6f409db483 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 203520802091..caddf0efcc77 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -157,39 +157,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "GlobalVar(" << node->name_hint << ")"; }); -// Container printer -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '['; - for (size_t i = 0; i < op->size(); ++i) { - if (i != 0) { - p->stream << ", "; - } - p->Print(op->at(i)); - } - p->stream << ']'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->begin(); it != op->end(); ++it) { - if (it != op->begin()) { - p->stream << ", "; - } - if (it->first->IsInstance()) { - p->stream << '\"' << Downcast(it->first) << "\": "; - } else { - p->Print(it->first); - p->stream << ": "; - } - p->Print(it->second); - } - p->stream << '}'; - }); - TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { std::stringstream ss; ss << ref; diff --git a/src/ir/op.cc b/src/ir/op.cc index 861545e6b959..fac15a7daad4 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include #include diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 7760334af44c..9537ef532b44 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index f84be1467453..050f9e5b2845 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -25,7 +25,6 @@ #define TVM_NODE_ATTR_REGISTRY_H_ #include -#include #include #include diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc new file mode 100644 index 000000000000..7b972966bef8 --- /dev/null +++ b/src/node/container_printing.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Printer implementation for containers + * \file node/container_printint.cc + */ +#include +#include +#include + +namespace tvm { + +// Container printer +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '['; + for (size_t i = 0; i < op->size(); ++i) { + if (i != 0) { + p->stream << ", "; + } + p->Print(op->at(i)); + } + p->stream << ']'; + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->begin(); it != op->end(); ++it) { + if (it != op->begin()) { + p->stream << ", "; + } + if (it->first->IsInstance()) { + p->stream << '\"' << Downcast(it->first) << "\": "; + } else { + p->Print(it->first); + p->stream << ": "; + } + p->Print(it->second); + } + p->stream << '}'; + }); + +} // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 79a53aa26440..a7c3493e7feb 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include namespace tvm { diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 75f03fbc7954..94dfda556cc9 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 05327b1ca303..f5344ab9126e 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include diff --git a/src/parser/op_table.h b/src/parser/op_table.h index 050904f23280..28c9cd7fc05f 100644 --- a/src/parser/op_table.h +++ b/src/parser/op_table.h @@ -28,7 +28,6 @@ #define TVM_PARSER_OP_TABLE_H_ #include -#include #include #include diff --git a/src/parser/span_check.h b/src/parser/span_check.h index ab71d30a54f5..0074c66d61f4 100644 --- a/src/parser/span_check.h +++ b/src/parser/span_check.h @@ -29,7 +29,6 @@ #include #include #include -#include #include #include diff --git a/src/parser/token.h b/src/parser/token.h index 1133483fa8f8..31e974355e4b 100644 --- a/src/parser/token.h +++ b/src/parser/token.h @@ -26,7 +26,6 @@ #define TVM_PARSER_TOKEN_H_ #include -#include #include #include diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index 5e71794cc7fb..0f407cef52c6 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -25,7 +25,6 @@ #define TVM_PARSER_TOKENIZER_H_ #include -#include #include #include diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index f76c32d353cf..b2e245bd5b45 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -25,7 +25,6 @@ #define TVM_PRINTER_META_DATA_H_ #include -#include #include #include diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 85a9c51a2fa8..840878390018 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -21,7 +21,6 @@ #include #include -#include #include #include diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index d225cb8ae82a..2e4eec23f733 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -33,7 +33,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 8dd6819e0e8c..35813f67d094 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -59,7 +59,6 @@ #include #include #include -#include #include namespace tvm { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 117a478f9ebc..f72f3bd73557 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index b81fd14b99c2..32eecec25b06 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 192e09140375..4966f3f01c7d 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index c9a58282d13e..e96255e976e9 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -393,7 +393,6 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; - code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; // dnnl_kernel file is saved under src/runtime/contrib/dnnl so that we don't diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 38cb763883b7..1ac800f357b0 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -27,7 +27,6 @@ #include #include #include -#include #include diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index e365dca3860f..b12e25a425b6 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -27,7 +27,6 @@ #include #include #include -#include #include "pass_utils.h" diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index fe5f547449ad..57603035b848 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -27,7 +27,6 @@ #include #include #include -#include #include #include diff --git a/src/relay/transforms/label_ops.cc b/src/relay/transforms/label_ops.cc index e0d3892a8d01..861342b03a76 100644 --- a/src/relay/transforms/label_ops.cc +++ b/src/relay/transforms/label_ops.cc @@ -19,7 +19,6 @@ #include #include #include -#include namespace tvm { namespace relay { diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 94891c3c98ea..1dda0d5cf429 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -35,7 +35,6 @@ #include #include #include -#include #include #include diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 3d9b1481f6e6..9d648dcb9a5f 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -21,7 +21,11 @@ * \file src/runtime/container.cc * \brief Implementations of common containers. */ -#include +#include +#include +#include +#include +#include #include #include #include @@ -29,6 +33,42 @@ namespace tvm { namespace runtime { +// Array +TVM_REGISTER_OBJECT_TYPE(ArrayNode); + +TVM_REGISTER_GLOBAL("runtime.Array").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector data; + for (int i = 0; i < args.size(); ++i) { + if (args[i].type_code() != kTVMNullptr) { + data.push_back(args[i].operator ObjectRef()); + } else { + data.push_back(ObjectRef(nullptr)); + } + } + *ret = Array(data); +}); + +TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + int64_t i = args[1]; + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; + *ret = n->at(i); +}); + +TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + *ret = static_cast(static_cast(ptr)->size()); +}); + +// ADT + +TVM_REGISTER_OBJECT_TYPE(ADTObj); + TVM_REGISTER_GLOBAL("runtime.GetADTTag").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); @@ -67,6 +107,10 @@ TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ADT(tag, fields); }); +// String + +TVM_REGISTER_OBJECT_TYPE(StringObj); + TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) { return String(std::move(str)); }); @@ -75,40 +119,7 @@ TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) { return std::string(str); }); -TVM_REGISTER_OBJECT_TYPE(ADTObj); -TVM_REGISTER_OBJECT_TYPE(StringObj); -TVM_REGISTER_OBJECT_TYPE(ClosureObj); - -TVM_REGISTER_OBJECT_TYPE(ArrayNode); - -TVM_REGISTER_GLOBAL("runtime.Array").set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector data; - for (int i = 0; i < args.size(); ++i) { - if (args[i].type_code() != kTVMNullptr) { - data.push_back(args[i].operator ObjectRef()); - } else { - data.push_back(ObjectRef(nullptr)); - } - } - *ret = Array(data); -}); - -TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { - int64_t i = args[1]; - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; - *ret = n->at(i); -}); - -TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - *ret = static_cast(static_cast(ptr)->size()); -}); +// Map TVM_REGISTER_OBJECT_TYPE(MapNode); @@ -174,5 +185,7 @@ TVM_REGISTER_GLOBAL("runtime.MapItems").set_body([](TVMArgs args, TVMRetValue* r TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; #endif +TVM_REGISTER_OBJECT_TYPE(ClosureObj); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 55f16635b9e6..1735d8569215 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -25,7 +25,6 @@ #ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_ -#include #include #include diff --git a/src/runtime/contrib/onnx/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc index b235d63dbc58..8732b700a218 100644 --- a/src/runtime/contrib/onnx/onnx_module.cc +++ b/src/runtime/contrib/onnx/onnx_module.cc @@ -21,7 +21,6 @@ * \file onnx_module.cc * \brief ONNX Module without runtime support */ -#include #include #include diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index 718d10d5df70..4e7f158bb04f 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -24,7 +24,8 @@ #ifndef TVM_RUNTIME_FILE_UTILS_H_ #define TVM_RUNTIME_FILE_UTILS_H_ -#include +#include +#include #include #include diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc index 5736462a648d..1ea01b19e8aa 100644 --- a/src/runtime/graph_executor/debug/graph_executor_debug.cc +++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc @@ -20,7 +20,7 @@ /*! * \file graph_executor_debug.cc */ -#include +#include #include #include #include diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 584aafe3410b..ad5b99e06c7e 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -22,7 +22,8 @@ */ #include "graph_executor.h" -#include +#include +#include #include #include #include diff --git a/src/runtime/graph_executor/graph_executor_factory.cc b/src/runtime/graph_executor/graph_executor_factory.cc index 8ea21cabf519..a13fbd860d43 100644 --- a/src/runtime/graph_executor/graph_executor_factory.cc +++ b/src/runtime/graph_executor/graph_executor_factory.cc @@ -24,7 +24,7 @@ #include "./graph_executor_factory.h" -#include +#include #include #include diff --git a/src/runtime/metadata_module.cc b/src/runtime/metadata_module.cc index 4a1d89ce1a1f..7cb986bba62c 100644 --- a/src/runtime/metadata_module.cc +++ b/src/runtime/metadata_module.cc @@ -27,7 +27,8 @@ * code and metadata significantly reduces the efforts for handling external * codegen and runtimes. */ -#include +#include +#include #include #include #include diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 7db84862604f..7272269680c5 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -21,7 +21,7 @@ * \file rpc_module.cc * \brief RPC runtime module. */ -#include +#include #include #include #include diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 4e7fe3196d45..1456fc719113 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -21,7 +21,6 @@ * \file rpc_socket_impl.cc * \brief Socket based RPC implementation. */ -#include #include #include diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 0a7795d600fe..a7d65944d581 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -24,6 +24,7 @@ #include "vm.h" +#include #include #include diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 17a66e419316..c8d0d6bb1fe6 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -23,7 +23,7 @@ */ #include -#include +#include #include #include #include diff --git a/src/support/array.h b/src/support/array.h index 12d76d18db21..2cf416c471ec 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -18,7 +18,7 @@ */ #ifndef TVM_SUPPORT_ARRAY_H_ #define TVM_SUPPORT_ARRAY_H_ -#include +#include #include diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index ea3a22e8ab01..4b5dc9080df1 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include #include #include diff --git a/src/support/utils.h b/src/support/utils.h index 075351760686..d807c5b8bb63 100644 --- a/src/support/utils.h +++ b/src/support/utils.h @@ -32,7 +32,7 @@ #endif // __hexagon__ #endif // _WIN32 -#include +#include #include #include diff --git a/src/target/build_common.h b/src/target/build_common.h index 1816c3ac2650..d2fe6468eef8 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -25,7 +25,6 @@ #define TVM_TARGET_BUILD_COMMON_H_ #include -#include #include #include #include diff --git a/src/target/codegen.cc b/src/target/codegen.cc index cf400d90747b..5a4aa39f01b4 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include diff --git a/src/target/func_registry_generator.h b/src/target/func_registry_generator.h index fb5964859352..8d2af305a0e4 100644 --- a/src/target/func_registry_generator.h +++ b/src/target/func_registry_generator.h @@ -24,7 +24,8 @@ #ifndef TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ #define TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ -#include +#include +#include #include #include diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 5dbceec32ed7..42957152ea12 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -22,7 +22,6 @@ #include #include #include -#include #include #include #include diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e56a6de6d914..d5fcfab6d889 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -27,7 +27,6 @@ #include #include -#include #include #include #include diff --git a/src/target/llvm/codegen_params.h b/src/target/llvm/codegen_params.h index 771bc201f7aa..f5fd21ff326d 100644 --- a/src/target/llvm/codegen_params.h +++ b/src/target/llvm/codegen_params.h @@ -24,7 +24,6 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_PARAMS_H_ #define TVM_TARGET_LLVM_CODEGEN_PARAMS_H_ -#include #include #include "llvm_common.h" diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 1791a5574c11..b967c7ad44e0 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -37,7 +37,6 @@ #include #include #include -#include #if TVM_LLVM_VERSION >= 100 #include #include @@ -78,6 +77,7 @@ #include #include #include +#include #include #include diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 3eab00c643e5..6b05d4bdf2d5 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -25,7 +25,6 @@ #ifndef TVM_TARGET_LLVM_LLVM_MODULE_H_ #define TVM_TARGET_LLVM_LLVM_MODULE_H_ -#include #include #include diff --git a/src/target/metadata_module.h b/src/target/metadata_module.h index add05ba52692..9311ee78ca6a 100644 --- a/src/target/metadata_module.h +++ b/src/target/metadata_module.h @@ -25,7 +25,6 @@ #ifndef TVM_TARGET_METADATA_MODULE_H_ #define TVM_TARGET_METADATA_MODULE_H_ -#include #include #include #include diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 76e6a9bc7197..ae451f39f89b 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -25,7 +25,6 @@ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ #include -#include #include #include #include diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 03fef4709b5e..2d93989730c7 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -22,7 +22,6 @@ */ #include "codegen_c_host.h" -#include #include #include #include diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index 6226ba2f22b3..8ed08048cf2f 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -25,7 +25,6 @@ #ifndef TVM_TARGET_SOURCE_SOURCE_MODULE_H_ #define TVM_TARGET_SOURCE_SOURCE_MODULE_H_ -#include #include #include diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index dc625b6a928d..d8f0f8e90238 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -23,7 +23,6 @@ */ #include "codegen_spirv.h" -#include #include #include #include diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 0dd96e07ed96..402e3291975f 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -23,7 +23,6 @@ #include "codegen_stackvm.h" #include -#include #include #include #include diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 32cc51039be0..5c59961fe011 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -36,7 +36,6 @@ * - Add annotation of extern buffers using the buffer_map field * in the PrimFunc type. */ -#include #include #include #include diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 0cc0086897d8..ee52a6fc0988 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,7 +20,6 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ -#include #include #include #include diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 154d0bfa5787..6e8793fbd367 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -20,7 +20,6 @@ /*! * \file make_unpacked_api.cc Lower PrimFunc to a standard C function API. */ -#include #include #include #include diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 921c7ad79509..f01d98707586 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -22,7 +22,6 @@ * \brief Split device function from host. */ #include -#include #include #include #include diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index 5e4533733d2e..16dfd56a69ea 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 63819308a666..7d1fa790146e 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -19,7 +19,10 @@ #include #include -#include +#include +#include +#include +#include #include #include diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index cf22577a791a..f993f9605c91 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -19,7 +19,6 @@ #include #include -#include #include #include #include diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 9d4255c86b5e..39fd575ff6d8 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -16,6 +16,7 @@ # under the License. import numpy as np +import random import tvm import tvm.testing import pickle diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 77ce6be66e63..3054bd0d7109 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -31,7 +31,6 @@ #define DMLC_USE_LOGGING_LIBRARY #include -#include #include #include #include