Skip to content

Commit

Permalink
[AOT] Support LLVM backend with C++ runtime (apache#10753)
Browse files Browse the repository at this point in the history
* add get_c_struct_name() method to Metadata to distinguish struct type name in llvm

* add metadata serialization support to llvm codegen

* Organize MetadataQueuer into a separate file.

* Add DiscoverArraysVisitor to metadata_utils

* Fill DLTensor metadata in LegalizePackedCalls.

* Improve error message from Call asserts

* Pass non-String device_context down to codegen.

 * this is necessary to allow CodeGenCPU to emit calls that include resource_handle.

* Scope usage of lvalue refs in LowerTVMBuiltin to avoid corrupt memory.

* test fixes

* Also fill preflattened_buffer_map (TODO, maybe don't do this)

* Fix C codegen.

* Set USMP elem_offset to 0.

* Clarify calculation of byte_offset from elem_offset.

* fix tests

* Fix arm compile warning

* Fix hexagon test.

 * previously I believe we required interface_api == "c", but
   this really means to generate C API bindings, and we are generating
   "packed" bindings.
 * I think "c" was chosen here because the distinction between
   interface-api and use-unpacked-api is confusing. "c" interface-api
   means to generate an entrypoint API for microcontrollers that
   accepts bare data buffers. "packed" interface-api means to generate
   a TVMBackendPackedCFunc entrypoint. use-unpacked-api forms the same
   determination for the operator functions.
 * A further confusion here is that there are two ways to call
   "packed" operator functions: tir.tvm_builtin_call_packed and
   tir.tvm_builtin_call_cpacked. This distinction describes whether or
   not to late-bind calls via TVMBackendGetFuncFromEnv. Right now, AOT
   only ever requires call_cpacked because target_host == target, and
   for all suitable target_host, we expect a single DSO-exportable
   runtime.Module. When we move away from this by introducing
   heterogeneous target support to AOT, we can use this as a condition
   to help us choose between call_cpacked and call_packed (and
   possibly add a compile-time option to assert it is call_cpacked,
   for situations where we really don't want call_packed).

* Document T.preflattened_buffer

* Fix test_aot_legalize_packed_calls

* Address manupa comments

* Fix convert_pool_allocations_to_offsets test.

* lint

* Fix T.preflattened_buffer

* Add preflattened_buffer_map to TIRTextPrinter

* Fix tests

* Fix BYOC

* Fix invoking C device API.

* remove comments

* Address Mousius comments

* lint

* lint

* Fix GMock linking on new CMake

* address masahi comment

Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
  • Loading branch information
2 people authored and Lucien0 committed Apr 19, 2022
1 parent 9e7bac7 commit ad76dfa
Show file tree
Hide file tree
Showing 35 changed files with 1,667 additions and 527 deletions.
21 changes: 20 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,25 @@ if(USE_GTEST)
find_package(GTest REQUIRED)
endif()
if(GTEST_FOUND)
if(NOT TARGET GTest::gmock)
# GMock is formally supported in CMake 3.20; for now, expect libgmock.a in the same directory,
# and require that folks compiling against GTest::gmock also link against GTest::GTest
# (for the includes dir).
add_library(GTest::gmock STATIC IMPORTED GLOBAL)
get_target_property(GTEST_LIB_PATH GTest::GTest IMPORTED_LOCATION)
if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND")
# CMake >= 3.20 makes GTest::GTest into a compatibility target. The real import location is in
# GTest::gtest.
get_target_property(GTEST_LIB_PATH GTest::gtest IMPORTED_LOCATION)
if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND")
message(FATAL_ERROR "Neither GTest::GTest nor GTets::gtest targets defined IMPORTED_LOCATION")
endif()
endif()
get_filename_component(GTEST_LIB_DIR "${GTEST_LIB_PATH}" DIRECTORY)
set_target_properties(GTest::gmock PROPERTIES
IMPORTED_LOCATION "${GTEST_LIB_DIR}/libgmock.a")
endif()

enable_testing()
include(CTest)
endif()
Expand Down Expand Up @@ -626,7 +645,7 @@ if(GTEST_FOUND)
add_executable(cpptest ${TEST_SRCS})
# include runtime files for unit testing
target_include_directories(cpptest PUBLIC "src/runtime")
target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main pthread dl)
target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GTest::gmock pthread dl)
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1)
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
# For some reason, compile definitions are not propagated correctly, so we manually add them here
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class MetadataNode : public MetadataBaseNode {
public:
explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {}
static constexpr const char* _type_key = "metadata.MetadataNode";
const char* get_c_struct_name() const override;
inline int64_t version() const { return int64_t(data_->version); }
inline int64_t num_inputs() const { return data_->num_inputs; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
Expand All @@ -141,6 +142,7 @@ class TensorInfoNode : public MetadataBaseNode {
public:
explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {}
static constexpr const char* _type_key = "metadata.TensorInfoNode";
const char* get_c_struct_name() const override;
inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); }
inline int64_t num_shape() const { return data_->num_shape; }
inline ::tvm::support::Span<const int64_t, int64_t> shape() const {
Expand Down
31 changes: 25 additions & 6 deletions include/tvm/runtime/metadata_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ namespace metadata {
*/
class MetadataBaseNode : public ::tvm::runtime::Object {
public:
virtual const char* get_c_struct_name() const = 0;

static constexpr const char* _type_key = "metadata.MetadataBaseNode";
TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
};
Expand Down Expand Up @@ -157,7 +159,7 @@ class ArrayAccessor<const char*, ::tvm::runtime::String> {
*
* These are separate from TIR DataType because TIR does not model structs.
*/
enum MetadataTypeIndex : uint8_t {
enum MetadataKind : uint8_t {
kUint64 = 0,
kInt64 = 1,
kBool = 2,
Expand All @@ -173,20 +175,37 @@ enum MetadataTypeIndex : uint8_t {
*/
class MetadataArrayNode : public MetadataBaseNode {
public:
MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name)
: array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {}
MetadataArrayNode(Array<ObjectRef> array, MetadataKind kind, const char* type_key)
: array(::std::move(array)), kind{kind}, type_key{type_key} {}

const char* get_c_struct_name() const final;

std::string get_element_c_struct_name() const {
CHECK(kind == MetadataKind::kMetadata)
<< "cannot get struct name for MetadataArray with kind=" << kind;
constexpr int prefix_size = sizeof("metadata.") - 1;
constexpr int suffix_size = sizeof("Node") - 1;
std::string type_key_str(type_key);
return std::string("TVM") +
type_key_str.substr(prefix_size, type_key_str.size() - prefix_size - suffix_size);
}

Array<ObjectRef> array;
MetadataTypeIndex type_index;
const char* struct_name;

/*! \brief Describes the storage class of the emitted struct member. */
MetadataKind kind;

/*! \brief When `kind` is Metadata, type_key of the MetadataBaseNode used with this array. */
const char* type_key;

static constexpr const char* _type_key = "metadata.MetadataArrayNode";
TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
};

/*! \brief Reference class for MetadataArray. */
class MetadataArray : public MetadataBase {
public:
MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name);
MetadataArray(Array<ObjectRef> array, MetadataKind kind, const char* struct_name);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
};
Expand Down
21 changes: 20 additions & 1 deletion python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,8 @@ class PreflattenedBufferMap(SpecialStmt):
Example
-------
.. code-block:: python
T.preflattened_buffer_map({})
A0 = T.match_buffer(A, (48,), dtype="float32")
T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
"""

def __init__(self):
Expand All @@ -892,12 +893,30 @@ def preflattened_buffer(
for key, value in self.context.func_buffer_map.items():
if value.same_as(postflattened):
param = key
break

assert (
param is not None
), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."

if data is None:
data = self.context.func_buffer_map[param].data

buffer_name: str = f"{postflattened.name}_preflatten"
if align != -1:
if isinstance(align, IntImm):
align = align.value
else:
assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"

if offset_factor != 0:
if isinstance(offset_factor, IntImm):
offset_factor = offset_factor.value
else:
assert isinstance(
offset_factor, int
), f"offset_factor: want int or IntImm, got {offset_factor!r}"

preflattened = tvm.tir.decl_buffer(
shape,
dtype,
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/testing/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
# pylint: disable=invalid-name, import-outside-toplevel, unused-variable
"""Common utility functions in TVM tir"""
import inspect
import re
import tvm
from tvm.ir.diagnostics import override_renderer


CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$")


def check_error(func, rel_lineno):
"""check if TIR script throws error"""
# Override the default renderer to accumulate errors
Expand All @@ -46,3 +50,12 @@ def render(e):
assert (
d.span.line - 1 == rel_lineno
), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}"

error_line = source_code.split("\n")[rel_lineno]
m = CHECK_ERROR_RE.match(error_line)
if m:
expected_error_text = m.group(1)
errors = [e.message for e in errors]
assert (
expected_error_text in errors
), f'check_error expects "{expected_error_text} in str(errors): {errors}'
11 changes: 11 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
doc << Doc::Indent(
2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
}

if (op->preflattened_buffer_map.size() != 0) {
// print preflattened_buffer_map
std::vector<Doc> preflattened_buffer_map_doc;
for (auto& v : op->preflattened_buffer_map) {
preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second));
}
doc << Doc::Indent(2, Doc::NewLine()
<< "preflattened_buffer_map = {"
<< PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}");
}
doc << PrintBody(op->body);
return doc;
}
Expand Down
Loading

0 comments on commit ad76dfa

Please sign in to comment.