From 33082e0032fb57b0516ad7e3eabd11fe0203437e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 22 Feb 2022 09:34:23 -0800 Subject: [PATCH] [runtime] Add Metadata classes for AOTExecutor (#10282) * Add new Metadata classes and base implementation. * These were autogenerated in the original PR, but checking them in as plain code until we can revisit the auto-generator approach. * address masa comments * Add documentation per Manupa's comments, and move kMetadataVersion namespace. * remove get_name function, used for debugging * clang-format --- include/tvm/runtime/metadata.h | 160 +++++++++++++++++++ include/tvm/runtime/metadata_base.h | 198 +++++++++++++++++++++++ include/tvm/support/span.h | 103 ++++++++++++ src/runtime/metadata.cc | 56 +++++++ src/target/metadata.cc | 47 ++++++ src/target/metadata.h | 173 ++++++++++++++++++++ tests/cpp/aot_metadata_test.cc | 236 ++++++++++++++++++++++++++++ 7 files changed, 973 insertions(+) create mode 100644 include/tvm/runtime/metadata.h create mode 100644 include/tvm/runtime/metadata_base.h create mode 100644 include/tvm/support/span.h create mode 100644 src/runtime/metadata.cc create mode 100644 src/target/metadata.cc create mode 100644 src/target/metadata.h create mode 100644 tests/cpp/aot_metadata_test.cc diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h new file mode 100644 index 000000000..b716d41c5 --- /dev/null +++ b/include/tvm/runtime/metadata.h @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/metadata.h + * \brief Defines types which can be used in Metadata. + */ +#ifndef TVM_RUNTIME_METADATA_H_ +#define TVM_RUNTIME_METADATA_H_ + +#include +#ifdef __cplusplus +#include +#include +#include +#endif +#include +#ifdef __cplusplus +#include +#endif +#include + +// Version number recorded in emitted artifacts for runtime checking. +#define TVM_METADATA_VERSION 1 + +namespace tvm { +namespace runtime { +namespace metadata { +/*! + * \brief Version of metadata emitted and understood by this compiler/runtime. + * Should be populated into the `version` field of all TVMMetadata. + */ +static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; +} // namespace metadata +} // namespace runtime +} // namespace tvm + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief Top-level metadata structure. Holds all other metadata types. + */ +struct TVMMetadata { + /*! \brief Version identifier for this metadata. */ + int64_t version; + /*! \brief Inputs to the AOT run_model function. + * The order of the elements is the same as in the arguments to run_model. That is to say, + * this array specifies the first `num_inputs` arguments to run_model. + */ + const struct TVMTensorInfo* inputs; + /*! \brief Number of elements in `inputs` array. */ + int64_t num_inputs; + /*! \brief Outputs of the AOT run_model function. + * The order of the elements is the same as in the arguments to run_model. That is to say, + * this array specifies the last `num_outputs` arguments to run_model. + */ + const struct TVMTensorInfo* outputs; + /*! \brief Number of elements in `outputs` array. */ + int64_t num_outputs; + /*! \brief Name of the model, as passed to tvm.relay.build. */ + const char* mod_name; +}; + +/*! + * \brief Describes one tensor argument to `run_model`. + * NOTE: while TIR allows for other types of arguments, such as scalars, the AOT run_model + * function does not currently accept these. Therefore it's not possible to express those + * in this metadata. A future patch may modify this. + */ +struct TVMTensorInfo { + /*! \brief Name of the tensor, as specified in the Relay program. */ + const char* name; + /*! \brief Shape of the tensor. */ + const int64_t* shape; + /*! \brief Rank of this tensor. */ + int64_t num_shape; + /*! \brief Data type of one element of this tensor. */ + DLDataType dtype; +}; +#ifdef __cplusplus +} // extern "C" +#include +namespace tvm { +namespace runtime { +namespace metadata { + +class Metadata; +class TensorInfo; + +class MetadataNode : public MetadataBaseNode { + public: + explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} + static constexpr const char* _type_key = "metadata.MetadataNode"; + inline int64_t version() const { return int64_t(data_->version); } + inline int64_t num_inputs() const { return data_->num_inputs; } + ArrayAccessor inputs(); + inline int64_t num_outputs() const { return data_->num_outputs; } + ArrayAccessor outputs(); + inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); } + const struct ::TVMMetadata* data() const { return data_; } + TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode); + + private: + const struct ::TVMMetadata* data_; +}; + +class Metadata : public MetadataBase { + public: + explicit Metadata(const struct ::TVMMetadata* data); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode); +}; + +class TensorInfoNode : public MetadataBaseNode { + public: + explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} + static constexpr const char* _type_key = "metadata.TensorInfoNode"; + inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } + inline int64_t num_shape() const { return data_->num_shape; } + inline ::tvm::support::Span shape() const { + return ::tvm::support::Span(data_->shape, + data_->shape + data_->num_shape); + } + inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); } + const struct ::TVMTensorInfo* data() const { return data_; } + TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode); + + private: + const struct ::TVMTensorInfo* data_; +}; + +class TensorInfo : public MetadataBase { + public: + explicit TensorInfo(const struct ::TVMTensorInfo* data); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode); +}; + +} // namespace metadata +} // namespace runtime +} // namespace tvm +#endif // defined(__cplusplus) + +#endif // TVM_RUNTIME_METADATA_H_ diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h new file mode 100644 index 000000000..96743199f --- /dev/null +++ b/include/tvm/runtime/metadata_base.h @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/metadata_base.h + * \brief Defines types which can be used in Metadata. + */ +#ifndef TVM_RUNTIME_METADATA_BASE_H_ +#define TVM_RUNTIME_METADATA_BASE_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace metadata { + +/*! + * \brief Common base class for all Metadata. + * + * This class is used in the visitor classes as a internal check to ensure that verify that all + * parts of the Metadata struct used in codegen are Metadata objects. + */ +class MetadataBaseNode : public ::tvm::runtime::Object { + public: + static constexpr const char* _type_key = "metadata.MetadataBaseNode"; + TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); +}; + +/*! \brief Reference class for the common MetadataBaseNode class. */ +class MetadataBase : public ::tvm::runtime::ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode); +}; + +template +class ArrayAccessor; + +/*! \brief An iterator implementation that lazily instantiates the C++ wrapping Metadata class. */ +template +class ArrayIterator { + public: + ArrayIterator(size_t index, const ArrayAccessor* parent) + : index_{index}, parent_{parent} {} + + inline Ref operator*() { return (*parent_)[index_]; } + + inline ArrayIterator& operator++() { + if (index_ < parent_->size()) { + index_++; + } + + return *this; + } + + inline bool operator==(const ArrayIterator& other) const { + return parent_ == other.parent_ && index_ == other.index_; + } + + inline bool operator!=(const ArrayIterator& other) const { return !operator==(other); } + + private: + size_t index_; + const ArrayAccessor* parent_; +}; + +/*! \brief A span-like class which permits access to Array fields with complex elements. + * These array fields should be accessed from C++ using the Metadata wrapper classes. This class + * lazily instantiates those wrappers as they are accessed. + */ +template +class ArrayAccessor { + public: + using value_type = Ref; + using iterator = ArrayIterator; + using const_iterator = iterator; + + template ::value>::type> + ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {} + + inline size_t size() const { return num_data_; } + + inline Ref operator[](size_t index) const { + if (index >= num_data_) { + throw std::runtime_error("Index out of range"); + } + + return Ref(&data_[index]); + } + + inline ArrayIterator begin() const { return ArrayIterator{0, this}; } + + inline ArrayIterator end() const { return ArrayIterator{num_data_, this}; } + + private: + const C* data_; + size_t num_data_; +}; + +/*! \brief A specialization of ArrayAccessor for String. + * This class is needed because the String constructor signature is different from the typical + * Metadata subclass. + */ +template <> +class ArrayAccessor { + public: + using value_type = ::tvm::runtime::String; + using iterator = ArrayIterator; + using const_iterator = iterator; + + ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {} + + inline size_t size() const { return num_data_; } + + inline ::tvm::runtime::String operator[](size_t index) const { + if (index >= num_data_) { + throw std::runtime_error("Index out of range"); + } + return ::tvm::runtime::String(data_[index]); + } + + inline ArrayIterator begin() const { + return ArrayIterator{0, this}; + } + + inline ArrayIterator end() const { + return ArrayIterator{num_data_, this}; + } + + private: + const char** data_; + size_t num_data_; +}; + +/*! \brief Enumerates the primitive types which can be part of a Metadata instance. + * + * These are separate from TIR DataType because TIR does not model structs. + */ +enum MetadataTypeIndex : uint8_t { + kUint64 = 0, + kInt64 = 1, + kBool = 2, + kString = 3, + kHandle = 4, + kMetadata = 5, +}; + +/*! \brief Container for arrays in the metadata. + * + * Type information is needed when emitting arrays. This container augments the data field with + * the necessary typing information. + */ +class MetadataArrayNode : public MetadataBaseNode { + public: + MetadataArrayNode(Array array, MetadataTypeIndex type_index, const char* struct_name) + : array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {} + + Array array; + MetadataTypeIndex type_index; + const char* struct_name; + static constexpr const char* _type_key = "metadata.MetadataArrayNode"; + TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); +}; + +/*! \brief Reference class for MetadataArray. */ +class MetadataArray : public MetadataBase { + public: + MetadataArray(Array array, MetadataTypeIndex type_index, const char* struct_name); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); +}; + +} // namespace metadata +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_METADATA_BASE_H_ diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h new file mode 100644 index 000000000..faa849c4a --- /dev/null +++ b/include/tvm/support/span.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file tvm/support/span.h + * \brief Reimplementation of part of C++-20 style span. + */ +#ifndef TVM_SUPPORT_SPAN_H_ +#define TVM_SUPPORT_SPAN_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace support { + +/*! + * \brief A partial implementation of the C++20 std::span. + * + * At the time of writing, TVM must compile against C++14. + */ +template +class Span { + public: + using value_type = W; + using const_W = typename ::std::add_const::type; + + template + class iterator_base : public std::iterator { + public: + inline iterator_base(T* ptr, T* end) : ptr_{ptr}, end_{end} { CHECK_GE(end, ptr); } + + inline W1 operator*() { return W1(*ptr_); } + + inline iterator_base& operator++() { + if (ptr_ != end_) ptr_++; + return *this; + } + + inline bool operator==(iterator_base other) { + return ptr_ == other.ptr_ && end_ == other.end_; + } + + inline bool operator!=(iterator_base other) { return !(*this == other); } + + template ::value> > + inline operator iterator_base() const { + return iterator_base(ptr_, end_); + } + + private: + T* ptr_; + T* end_; + }; + + using iterator = iterator_base; + using const_iterator = iterator_base; + + inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {} + inline Span(T* begin, T* end) : begin_{begin}, end_{end} {} + + inline iterator begin() const { return iterator(begin_, end_); } + + inline iterator end() const { return iterator(end_, end_); } + + size_t size() const { return end_ - begin_; } + + inline W operator[](int i) { + T* to_return = begin_ + i; + ICHECK_LT(to_return, end_) << "Span access out of bounds: " << i; + return W(*to_return); + } + + inline operator std::vector() { return std::vector(begin(), end()); } + + protected: + T* begin_; + T* end_; +}; + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_SPAN_H_ diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc new file mode 100644 index 000000000..7ca333b06 --- /dev/null +++ b/src/runtime/metadata.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file metadata.cc + * \brief Implementations of the runtime component of Metadata. + */ + +#include + +namespace tvm { +namespace runtime { +namespace metadata { + +ArrayAccessor MetadataNode::inputs() { + return ArrayAccessor(data_->inputs, data_->num_inputs); +} +ArrayAccessor MetadataNode::outputs() { + return ArrayAccessor(data_->outputs, data_->num_outputs); +} + +TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode); + +MetadataArray::MetadataArray(Array array, MetadataTypeIndex type_index, + const char* struct_name) + : MetadataBase{make_object(array, type_index, struct_name)} {} + +TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode); + +Metadata::Metadata(const struct ::TVMMetadata* data) + : MetadataBase{make_object(data)} {} +TVM_REGISTER_OBJECT_TYPE(MetadataNode); + +TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data) + : MetadataBase{make_object(data)} {} +TVM_REGISTER_OBJECT_TYPE(TensorInfoNode); + +} // namespace metadata +} // namespace runtime +} // namespace tvm diff --git a/src/target/metadata.cc b/src/target/metadata.cc new file mode 100644 index 000000000..adf4cba3e --- /dev/null +++ b/src/target/metadata.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file metadata.cc + * \brief Implementations of the compiler extensions for Metadata. + */ + +#include "metadata.h" + +#include + +namespace tvm { +namespace target { +namespace metadata { + +TVM_REGISTER_REFLECTION_VTABLE(VisitableMetadataNode, + ::tvm::detail::ReflectionTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +TVM_REGISTER_REFLECTION_VTABLE(VisitableTensorInfoNode, + ::tvm::detail::ReflectionTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +} // namespace metadata +} // namespace target +} // namespace tvm diff --git a/src/target/metadata.h b/src/target/metadata.h new file mode 100644 index 000000000..2621d5d4e --- /dev/null +++ b/src/target/metadata.h @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/metadata.h + * \brief Extends Metadata for use in the compiler. + */ +#ifndef TVM_TARGET_METADATA_H_ +#define TVM_TARGET_METADATA_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace target { +namespace metadata { + +/*! + * \brief Subclass of MetadataNode that implements the VisitAttrs reflection method. + * + * This implementation (and other such Visitable subclasses) is compiled into libtvm.so, but not + * libtvm_runtime.so, because reflection is not supported in libtvm_runtime.so over code size + * concerns. It is used during compilation by the generic metadata code-generators. + */ +class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode { + public: + explicit VisitableMetadataNode(const struct ::TVMMetadata* data) : MetadataNode{data} {} + VisitableMetadataNode() : MetadataNode{nullptr} {} + + void VisitAttrs(AttrVisitor* v) { + int64_t version_cpp{version()}; + v->Visit("version", &version_cpp); + auto inputs_array = Array(); + auto inputs_accessor = inputs(); + inputs_array.reserve(num_inputs()); + for (int64_t i = 0; i < num_inputs(); ++i) { + inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{ + inputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + v->Visit("inputs", &inputs_metadata_array); + int64_t num_inputs_cpp = num_inputs(); + v->Visit("num_inputs", &num_inputs_cpp); + auto outputs_array = Array(); + auto outputs_accessor = outputs(); + outputs_array.reserve(num_outputs()); + for (int64_t i = 0; i < num_outputs(); ++i) { + outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]}); + } + ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{ + outputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"}; + v->Visit("outputs", &outputs_metadata_array); + int64_t num_outputs_cpp = num_outputs(); + v->Visit("num_outputs", &num_outputs_cpp); + ::std::string mod_name_cpp{data()->mod_name}; + v->Visit("mod_name", &mod_name_cpp); + } +}; + +/*! + * \brief Subclass of MetadataNode which also owns the backing C structures. + * + * This class (and other InMemory subclasses) are used during compilation to instantiate Metadata + * instances whose storage lives outside of .rodata. This class exists because the Module returned + * from tvm.relay.build must also be ready to run inference. + */ +class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNode { + public: + InMemoryMetadataNode() + : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */, + "" /* mod_name */) {} + InMemoryMetadataNode(int64_t version, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs, + const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs, + const ::tvm::runtime::String mod_name) + : VisitableMetadataNode{&storage_}, + inputs_{new struct TVMTensorInfo[inputs.size()]()}, + inputs_objs_{inputs}, + outputs_{new struct TVMTensorInfo[outputs.size()]()}, + outputs_objs_{outputs}, + mod_name_{mod_name}, + storage_{version, nullptr, 0, nullptr, 0, mod_name_.c_str()} { + storage_.inputs = inputs_.get(); + storage_.num_inputs = inputs.size(); + for (unsigned int i = 0; i < inputs.size(); ++i) { + inputs_.get()[i] = *inputs[i]->data(); + } + storage_.outputs = outputs_.get(); + storage_.num_outputs = outputs.size(); + for (unsigned int i = 0; i < outputs.size(); ++i) { + outputs_.get()[i] = *outputs[i]->data(); + } + } + + private: + ::std::unique_ptr inputs_; + std::vector<::tvm::runtime::metadata::TensorInfo> inputs_objs_; + ::std::unique_ptr outputs_; + std::vector<::tvm::runtime::metadata::TensorInfo> outputs_objs_; + ::std::string mod_name_; + struct ::TVMMetadata storage_; +}; + +class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode { + public: + explicit VisitableTensorInfoNode(const struct ::TVMTensorInfo* data) : TensorInfoNode{data} {} + VisitableTensorInfoNode() : TensorInfoNode{nullptr} {} + + void VisitAttrs(AttrVisitor* v) { + ::std::string name_cpp{data()->name}; + v->Visit("name", &name_cpp); + auto shape_array = Array(); + auto shape_accessor = shape(); + shape_array.reserve(num_shape()); + for (int64_t i = 0; i < num_shape(); ++i) { + shape_array.push_back(::tvm::Integer{static_cast(shape_accessor[i])}); + } + ::tvm::runtime::metadata::MetadataArray shape_metadata_array{ + shape_array, ::tvm::runtime::metadata::MetadataTypeIndex::kInt64, nullptr}; + v->Visit("shape", &shape_metadata_array); + int64_t num_shape_cpp = num_shape(); + v->Visit("num_shape", &num_shape_cpp); + ::tvm::runtime::DataType dtype_cpp{dtype()}; + v->Visit("dtype", &dtype_cpp); + } +}; + +class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorInfoNode { + public: + InMemoryTensorInfoNode() : InMemoryTensorInfoNode("", {}, ::tvm::runtime::DataType(0, 0, 0)) {} + InMemoryTensorInfoNode(const ::tvm::runtime::String& name, const ::std::vector& shape, + ::tvm::runtime::DataType dtype) + : VisitableTensorInfoNode{&storage_}, + name_{name}, + shape_{new int64_t[shape.size()]()}, + storage_{name_.c_str(), nullptr, 0, dtype} { + storage_.shape = shape_.get(); + storage_.num_shape = shape.size(); + for (unsigned int i = 0; i < shape.size(); ++i) { + shape_.get()[i] = shape[i]; + } + } + + private: + ::std::string name_; + ::std::unique_ptr shape_; + struct ::TVMTensorInfo storage_; +}; + +} // namespace metadata +} // namespace target +} // namespace tvm + +#endif // TVM_TARGET_METADATA_H_ diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc new file mode 100644 index 000000000..730762237 --- /dev/null +++ b/tests/cpp/aot_metadata_test.cc @@ -0,0 +1,236 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../src/target/metadata.h" + +namespace { + +const int64_t kNormalInput1Shape[4] = {1, 5, 5, 3}; +const struct TVMTensorInfo kNormalInputs[2] = { + {"input1", kNormalInput1Shape, 4, DLDataType{1, 2, 3}}, + {"input2", kNormalInput1Shape, 4, DLDataType{2, 3, 4}}}; + +const int64_t kNormalOutput1Shape[3] = {3, 8, 8}; +const struct TVMTensorInfo kNormalOutputs[1] = { + {"output1", kNormalOutput1Shape, 3, DLDataType{3, 4, 5}}}; + +const struct TVMMetadata kNormal = { + TVM_METADATA_VERSION, kNormalInputs, 2, kNormalOutputs, 1, "default", +}; +} // namespace + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::StrEq; +using ::tvm::runtime::Downcast; + +TEST(Metadata, ParseStruct) { + tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + EXPECT_THAT(md->version(), Eq(TVM_METADATA_VERSION)); + EXPECT_THAT(md->num_inputs(), Eq(2)); + + auto inputs = md->inputs(); + EXPECT_THAT(inputs.size(), Eq(2)); + + auto input1 = inputs[0]; + EXPECT_THAT(input1->name(), Eq("input1")); + EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); + EXPECT_THAT(input1->dtype(), Eq(tvm::runtime::DataType(DLDataType{1, 2, 3}))); + + auto input2 = inputs[1]; + EXPECT_THAT(input2->name(), Eq("input2")); + EXPECT_THAT(input2->shape(), ElementsAre(1, 5, 5, 3)); + EXPECT_THAT(input2->dtype(), Eq(tvm::runtime::DataType(DLDataType{2, 3, 4}))); + + EXPECT_THAT(md->num_outputs(), Eq(1)); + auto outputs = md->outputs(); + EXPECT_THAT(outputs.size(), Eq(1)); + + auto output1 = outputs[0]; + EXPECT_THAT(output1->name(), Eq("output1")); + EXPECT_THAT(output1->shape(), ElementsAre(3, 8, 8)); + EXPECT_THAT(output1->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 5}))); + + EXPECT_THAT(md->mod_name(), Eq("default")); +} + +class TestVisitor : public tvm::AttrVisitor { + public: + using Element = ::std::tuple<::std::string, ::tvm::runtime::ObjectRef>; + void Visit(const char* key, double* value) final { + keys.push_back(key); + values.push_back(::tvm::FloatImm(::tvm::runtime::DataType(kDLFloat, 64, 1), *value)); + } + void Visit(const char* key, int64_t* value) final { + keys.push_back(key); + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLInt, 64, 1), *value)); + } + void Visit(const char* key, uint64_t* value) final { + keys.push_back(key); + int64_t v; + *(reinterpret_cast(&v)) = *value; + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLUInt, 64, 1), v)); + } + void Visit(const char* key, int* value) final { + keys.push_back(key); + values.push_back(::tvm::IntImm(::tvm::runtime::DataType(kDLInt, 64, 1), *value)); + } + void Visit(const char* key, bool* value) final { + keys.push_back(key); + values.push_back(::tvm::Bool(*value)); + } + void Visit(const char* key, std::string* value) final { + keys.push_back(key); + values.push_back(::tvm::runtime::String(*value)); + } + void Visit(const char* key, tvm::runtime::DataType* value) final { + keys.push_back(key); + values.push_back(::tvm::PrimType(*value)); + } + void Visit(const char* key, tvm::runtime::NDArray* value) final { + keys.push_back(key); + values.push_back(*value); + } + void Visit(const char* key, void** value) final { CHECK(false) << "Do not expect this type"; } + + void Visit(const char* key, ::tvm::runtime::ObjectRef* value) final { + keys.push_back(key); + values.push_back(*value); + } + + std::vector keys; + std::vector<::tvm::runtime::ObjectRef> values; +}; + +TEST(Metadata, Visitor) { + tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal); + TestVisitor v; + ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v); + + EXPECT_THAT(v.keys, ElementsAre(StrEq("version"), StrEq("inputs"), StrEq("num_inputs"), + StrEq("outputs"), StrEq("num_outputs"), StrEq("mod_name"))); + EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); + + EXPECT_THAT(Downcast(v.values[0])->value, Eq(TVM_METADATA_VERSION)); + + // Just identify the tensor. + auto input_array = Downcast(v.values[1]); + EXPECT_THAT(input_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); + EXPECT_THAT(input_array->struct_name, StrEq("TVMTensorInfo")); + EXPECT_THAT(input_array->array.size(), Eq(2)); + + auto input1 = Downcast(input_array->array[0]); + EXPECT_THAT(input1->name(), StrEq("input1")); + EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); + EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3})); + + auto input2 = Downcast(input_array->array[1]); + EXPECT_THAT(input1->name(), StrEq("input1")); + EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3)); + EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3})); + + auto num_inputs = Downcast(v.values[2]); + EXPECT_THAT(num_inputs->value, Eq(2)); + + auto output_array = Downcast(v.values[3]); + EXPECT_THAT(output_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata)); + EXPECT_THAT(output_array->struct_name, StrEq("TVMTensorInfo")); + auto output1 = Downcast(output_array->array[0]); + + EXPECT_THAT(output1->name(), Eq("output1")); + + auto num_outputs = Downcast(v.values[4]); + EXPECT_THAT(num_outputs->value, Eq(1)); +} + +using ::tvm::runtime::make_object; +TEST(Metadata, InMemory) { + tvm::runtime::metadata::Metadata md = + tvm::runtime::metadata::Metadata(make_object( + TVM_METADATA_VERSION, + std::vector( + {tvm::runtime::metadata::TensorInfo( + make_object( + tvm::String("Input1"), std::vector{1, 5, 5, 3}, + tvm::runtime::DataType(DLDataType{1, 2, 3}))), + tvm::runtime::metadata::TensorInfo( + make_object( + tvm::String("Input2"), std::vector{1, 5, 5, 3}, + tvm::runtime::DataType(DLDataType{2, 3, 4})))}), + std::vector({tvm::runtime::metadata::TensorInfo( + make_object( + tvm::String("Output1"), std::vector{3, 8, 8}, + tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + "default")); + + auto md_data = md->data(); + EXPECT_THAT(md_data->version, Eq(TVM_METADATA_VERSION)); + EXPECT_THAT(md_data->num_inputs, Eq(2)); + + auto input0 = &md_data->inputs[0]; + EXPECT_THAT(input0->name, StrEq("Input1")); + EXPECT_THAT(std::vector(input0->shape, input0->shape + input0->num_shape), + ElementsAre(1, 5, 5, 3)); + EXPECT_THAT(tvm::runtime::DataType(input0->dtype), + Eq(tvm::runtime::DataType(DLDataType({1, 2, 3})))); + + auto input1 = &md_data->inputs[1]; + EXPECT_THAT(input1->name, StrEq("Input2")); + EXPECT_THAT(std::vector(input1->shape, input1->shape + input1->num_shape), + ElementsAre(1, 5, 5, 3)); + EXPECT_THAT(tvm::runtime::DataType(input1->dtype), + Eq(tvm::runtime::DataType(DLDataType({2, 3, 4})))); + + auto output0 = &md_data->outputs[0]; + EXPECT_THAT(output0->name, StrEq("Output1")); + EXPECT_THAT(std::vector(output0->shape, output0->shape + output0->num_shape), + ElementsAre(3, 8, 8)); + EXPECT_THAT(tvm::runtime::DataType(output0->dtype), + Eq(tvm::runtime::DataType(DLDataType({3, 4, 5})))); + + EXPECT_THAT(md_data->mod_name, StrEq("default")); +} + +TEST(Metadata, ZeroElementLists) { + tvm::runtime::metadata::Metadata md = + tvm::runtime::metadata::Metadata(make_object( + TVM_METADATA_VERSION, std::vector({}), + std::vector({tvm::runtime::metadata::TensorInfo( + make_object( + tvm::String("Output1"), std::vector{}, + tvm::runtime::DataType(DLDataType{3, 4, 5})))}), + "default")); + + EXPECT_THAT(md->data()->num_inputs, Eq(0)); + EXPECT_THAT(md->inputs().size(), Eq(0)); + EXPECT_THAT(md->num_inputs(), Eq(0)); + EXPECT_THAT(md->inputs(), ElementsAre()); + + auto output0 = md->data()->outputs[0]; + EXPECT_THAT(output0.num_shape, Eq(0)); + EXPECT_THAT(md->outputs()[0]->shape().size(), Eq(0)); + EXPECT_THAT(md->outputs()[0]->shape(), ElementsAre()); +}