This repository has been archived by the owner on Apr 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[runtime] Add Metadata classes for AOTExecutor (#10282)
* Add new Metadata classes and base implementation. * These were autogenerated in the original PR, but checking them in as plain code until we can revisit the auto-generator approach. * address masa comments * Add documentation per Manupa's comments, and move kMetadataVersion namespace. * remove get_name function, used for debugging * clang-format
- Loading branch information
Showing
7 changed files
with
973 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/runtime/metadata.h | ||
* \brief Defines types which can be used in Metadata. | ||
*/ | ||
#ifndef TVM_RUNTIME_METADATA_H_ | ||
#define TVM_RUNTIME_METADATA_H_ | ||
|
||
#include <inttypes.h> | ||
#ifdef __cplusplus | ||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
#endif | ||
#include <tvm/runtime/c_runtime_api.h> | ||
#ifdef __cplusplus | ||
#include <tvm/runtime/metadata_base.h> | ||
#endif | ||
#include <tvm/support/span.h> | ||
|
||
// Version number recorded in emitted artifacts for runtime checking. | ||
#define TVM_METADATA_VERSION 1 | ||
|
||
namespace tvm { | ||
namespace runtime { | ||
namespace metadata { | ||
/*! | ||
* \brief Version of metadata emitted and understood by this compiler/runtime. | ||
* Should be populated into the `version` field of all TVMMetadata. | ||
*/ | ||
static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION; | ||
} // namespace metadata | ||
} // namespace runtime | ||
} // namespace tvm | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
/*! | ||
* \brief Top-level metadata structure. Holds all other metadata types. | ||
*/ | ||
struct TVMMetadata { | ||
/*! \brief Version identifier for this metadata. */ | ||
int64_t version; | ||
/*! \brief Inputs to the AOT run_model function. | ||
* The order of the elements is the same as in the arguments to run_model. That is to say, | ||
* this array specifies the first `num_inputs` arguments to run_model. | ||
*/ | ||
const struct TVMTensorInfo* inputs; | ||
/*! \brief Number of elements in `inputs` array. */ | ||
int64_t num_inputs; | ||
/*! \brief Outputs of the AOT run_model function. | ||
* The order of the elements is the same as in the arguments to run_model. That is to say, | ||
* this array specifies the last `num_outputs` arguments to run_model. | ||
*/ | ||
const struct TVMTensorInfo* outputs; | ||
/*! \brief Number of elements in `outputs` array. */ | ||
int64_t num_outputs; | ||
/*! \brief Name of the model, as passed to tvm.relay.build. */ | ||
const char* mod_name; | ||
}; | ||
|
||
/*! | ||
* \brief Describes one tensor argument to `run_model`. | ||
* NOTE: while TIR allows for other types of arguments, such as scalars, the AOT run_model | ||
* function does not currently accept these. Therefore it's not possible to express those | ||
* in this metadata. A future patch may modify this. | ||
*/ | ||
struct TVMTensorInfo { | ||
/*! \brief Name of the tensor, as specified in the Relay program. */ | ||
const char* name; | ||
/*! \brief Shape of the tensor. */ | ||
const int64_t* shape; | ||
/*! \brief Rank of this tensor. */ | ||
int64_t num_shape; | ||
/*! \brief Data type of one element of this tensor. */ | ||
DLDataType dtype; | ||
}; | ||
#ifdef __cplusplus | ||
} // extern "C" | ||
#include <tvm/runtime/object.h> | ||
namespace tvm { | ||
namespace runtime { | ||
namespace metadata { | ||
|
||
class Metadata; | ||
class TensorInfo; | ||
|
||
class MetadataNode : public MetadataBaseNode { | ||
public: | ||
explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {} | ||
static constexpr const char* _type_key = "metadata.MetadataNode"; | ||
inline int64_t version() const { return int64_t(data_->version); } | ||
inline int64_t num_inputs() const { return data_->num_inputs; } | ||
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs(); | ||
inline int64_t num_outputs() const { return data_->num_outputs; } | ||
ArrayAccessor<struct TVMTensorInfo, TensorInfo> outputs(); | ||
inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); } | ||
const struct ::TVMMetadata* data() const { return data_; } | ||
TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode); | ||
|
||
private: | ||
const struct ::TVMMetadata* data_; | ||
}; | ||
|
||
class Metadata : public MetadataBase { | ||
public: | ||
explicit Metadata(const struct ::TVMMetadata* data); | ||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode); | ||
}; | ||
|
||
class TensorInfoNode : public MetadataBaseNode { | ||
public: | ||
explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {} | ||
static constexpr const char* _type_key = "metadata.TensorInfoNode"; | ||
inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); } | ||
inline int64_t num_shape() const { return data_->num_shape; } | ||
inline ::tvm::support::Span<const int64_t, int64_t> shape() const { | ||
return ::tvm::support::Span<const int64_t, int64_t>(data_->shape, | ||
data_->shape + data_->num_shape); | ||
} | ||
inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); } | ||
const struct ::TVMTensorInfo* data() const { return data_; } | ||
TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode); | ||
|
||
private: | ||
const struct ::TVMTensorInfo* data_; | ||
}; | ||
|
||
class TensorInfo : public MetadataBase { | ||
public: | ||
explicit TensorInfo(const struct ::TVMTensorInfo* data); | ||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode); | ||
}; | ||
|
||
} // namespace metadata | ||
} // namespace runtime | ||
} // namespace tvm | ||
#endif // defined(__cplusplus) | ||
|
||
#endif // TVM_RUNTIME_METADATA_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/runtime/metadata_base.h | ||
* \brief Defines types which can be used in Metadata. | ||
*/ | ||
#ifndef TVM_RUNTIME_METADATA_BASE_H_ | ||
#define TVM_RUNTIME_METADATA_BASE_H_ | ||
|
||
#include <tvm/ir/expr.h> | ||
#include <tvm/runtime/object.h> | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace tvm { | ||
namespace runtime { | ||
namespace metadata { | ||
|
||
/*! | ||
* \brief Common base class for all Metadata. | ||
* | ||
* This class is used in the visitor classes as a internal check to ensure that verify that all | ||
* parts of the Metadata struct used in codegen are Metadata objects. | ||
*/ | ||
class MetadataBaseNode : public ::tvm::runtime::Object { | ||
public: | ||
static constexpr const char* _type_key = "metadata.MetadataBaseNode"; | ||
TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object); | ||
}; | ||
|
||
/*! \brief Reference class for the common MetadataBaseNode class. */ | ||
class MetadataBase : public ::tvm::runtime::ObjectRef { | ||
public: | ||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode); | ||
}; | ||
|
||
template <typename C, class Ref> | ||
class ArrayAccessor; | ||
|
||
/*! \brief An iterator implementation that lazily instantiates the C++ wrapping Metadata class. */ | ||
template <typename C, class Ref> | ||
class ArrayIterator { | ||
public: | ||
ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent) | ||
: index_{index}, parent_{parent} {} | ||
|
||
inline Ref operator*() { return (*parent_)[index_]; } | ||
|
||
inline ArrayIterator<C, Ref>& operator++() { | ||
if (index_ < parent_->size()) { | ||
index_++; | ||
} | ||
|
||
return *this; | ||
} | ||
|
||
inline bool operator==(const ArrayIterator<C, Ref>& other) const { | ||
return parent_ == other.parent_ && index_ == other.index_; | ||
} | ||
|
||
inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); } | ||
|
||
private: | ||
size_t index_; | ||
const ArrayAccessor<C, Ref>* parent_; | ||
}; | ||
|
||
/*! \brief A span-like class which permits access to Array fields with complex elements. | ||
* These array fields should be accessed from C++ using the Metadata wrapper classes. This class | ||
* lazily instantiates those wrappers as they are accessed. | ||
*/ | ||
template <typename C, class Ref> | ||
class ArrayAccessor { | ||
public: | ||
using value_type = Ref; | ||
using iterator = ArrayIterator<C, Ref>; | ||
using const_iterator = iterator; | ||
|
||
template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type> | ||
ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {} | ||
|
||
inline size_t size() const { return num_data_; } | ||
|
||
inline Ref operator[](size_t index) const { | ||
if (index >= num_data_) { | ||
throw std::runtime_error("Index out of range"); | ||
} | ||
|
||
return Ref(&data_[index]); | ||
} | ||
|
||
inline ArrayIterator<C, Ref> begin() const { return ArrayIterator<C, Ref>{0, this}; } | ||
|
||
inline ArrayIterator<C, Ref> end() const { return ArrayIterator<C, Ref>{num_data_, this}; } | ||
|
||
private: | ||
const C* data_; | ||
size_t num_data_; | ||
}; | ||
|
||
/*! \brief A specialization of ArrayAccessor for String. | ||
* This class is needed because the String constructor signature is different from the typical | ||
* Metadata subclass. | ||
*/ | ||
template <> | ||
class ArrayAccessor<const char*, ::tvm::runtime::String> { | ||
public: | ||
using value_type = ::tvm::runtime::String; | ||
using iterator = ArrayIterator<const char*, ::tvm::runtime::String>; | ||
using const_iterator = iterator; | ||
|
||
ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {} | ||
|
||
inline size_t size() const { return num_data_; } | ||
|
||
inline ::tvm::runtime::String operator[](size_t index) const { | ||
if (index >= num_data_) { | ||
throw std::runtime_error("Index out of range"); | ||
} | ||
return ::tvm::runtime::String(data_[index]); | ||
} | ||
|
||
inline ArrayIterator<const char*, ::tvm::runtime::String> begin() const { | ||
return ArrayIterator<const char*, ::tvm::runtime::String>{0, this}; | ||
} | ||
|
||
inline ArrayIterator<const char*, ::tvm::runtime::String> end() const { | ||
return ArrayIterator<const char*, ::tvm::runtime::String>{num_data_, this}; | ||
} | ||
|
||
private: | ||
const char** data_; | ||
size_t num_data_; | ||
}; | ||
|
||
/*! \brief Enumerates the primitive types which can be part of a Metadata instance. | ||
* | ||
* These are separate from TIR DataType because TIR does not model structs. | ||
*/ | ||
enum MetadataTypeIndex : uint8_t { | ||
kUint64 = 0, | ||
kInt64 = 1, | ||
kBool = 2, | ||
kString = 3, | ||
kHandle = 4, | ||
kMetadata = 5, | ||
}; | ||
|
||
/*! \brief Container for arrays in the metadata. | ||
* | ||
* Type information is needed when emitting arrays. This container augments the data field with | ||
* the necessary typing information. | ||
*/ | ||
class MetadataArrayNode : public MetadataBaseNode { | ||
public: | ||
MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name) | ||
: array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {} | ||
|
||
Array<ObjectRef> array; | ||
MetadataTypeIndex type_index; | ||
const char* struct_name; | ||
static constexpr const char* _type_key = "metadata.MetadataArrayNode"; | ||
TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode); | ||
}; | ||
|
||
/*! \brief Reference class for MetadataArray. */ | ||
class MetadataArray : public MetadataBase { | ||
public: | ||
MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name); | ||
|
||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode); | ||
}; | ||
|
||
} // namespace metadata | ||
} // namespace runtime | ||
} // namespace tvm | ||
|
||
#endif // TVM_RUNTIME_METADATA_BASE_H_ |
Oops, something went wrong.