Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AOT] Support LLVM backend with C++ runtime #10753

Merged
merged 33 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
3c44f7b
add get_c_struct_name() method to Metadata to distinguish struct type…
masahi Feb 28, 2022
783aa43
add metadata serialization support to llvm codegen
masahi Feb 28, 2022
f991e19
Organize MetadataQueuer into a separate file.
areusch Mar 4, 2022
f6b1314
Add DiscoverArraysVisitor to metadata_utils
areusch Mar 22, 2022
e8e93fc
Fill DLTensor metadata in LegalizePackedCalls.
areusch Mar 22, 2022
1477383
Improve error message from Call asserts
areusch Mar 22, 2022
1641cfd
Pass non-String device_context down to codegen.
areusch Mar 22, 2022
58da108
Scope usage of lvalue refs in LowerTVMBuiltin to avoid corrupt memory.
areusch Mar 22, 2022
cab2df8
test fixes
areusch Mar 22, 2022
43ad6d4
Also fill preflattened_buffer_map (TODO, maybe don't do this)
areusch Mar 22, 2022
acba246
Fix C codegen.
areusch Mar 22, 2022
fe910e9
Set USMP elem_offset to 0.
areusch Mar 22, 2022
1558cf7
Clarify calculation of byte_offset from elem_offset.
areusch Mar 22, 2022
4290e28
fix tests
areusch Mar 22, 2022
74283a7
Fix arm compile warning
areusch Mar 24, 2022
3764385
Fix hexagon test.
areusch Mar 24, 2022
48478c7
Document T.preflattened_buffer
areusch Mar 31, 2022
4bf22e9
Fix test_aot_legalize_packed_calls
areusch Mar 31, 2022
5de35ef
Address manupa comments
areusch Apr 5, 2022
d756d79
Fix convert_pool_allocations_to_offsets test.
areusch Apr 5, 2022
6563534
lint
areusch Apr 7, 2022
e39deed
Fix T.preflattened_buffer
areusch Apr 7, 2022
f2138d5
Add preflattened_buffer_map to TIRTextPrinter
areusch Apr 7, 2022
c257f7f
Fix tests
areusch Apr 7, 2022
4705a18
Fix BYOC
areusch Apr 7, 2022
9642548
Fix invoking C device API.
areusch Apr 8, 2022
66f0898
remove comments
areusch Apr 11, 2022
1b36e6e
Address Mousius comments
areusch Apr 13, 2022
792245a
lint
areusch Apr 13, 2022
8bbb750
Merge remote-tracking branch 'origin/main' into mbmr-aot-llvm
areusch Apr 13, 2022
131722e
lint
areusch Apr 13, 2022
ee1877c
Fix GMock linking on new CMake
areusch Apr 18, 2022
374da00
address masahi comment
areusch Apr 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,25 @@ if(USE_GTEST)
find_package(GTest REQUIRED)
endif()
if(GTEST_FOUND)
if(NOT TARGET GTest::gmock)
# GMock is formally supported in CMake 3.20; for now, expect libgmock.a in the same directory,
# and require that folks compiling against GTest::gmock also link against GTest::GTest
# (for the includes dir).
add_library(GTest::gmock STATIC IMPORTED GLOBAL)
get_target_property(GTEST_LIB_PATH GTest::GTest IMPORTED_LOCATION)
if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND")
# CMake >= 3.20 makes GTest::GTest into a compatibility target. The real import location is in
# GTest::gtest.
get_target_property(GTEST_LIB_PATH GTest::gtest IMPORTED_LOCATION)
if("${GTEST_LIB_PATH}" STREQUAL "GTEST_LIB_PATH-NOTFOUND")
message(FATAL_ERROR "Neither GTest::GTest nor GTets::gtest targets defined IMPORTED_LOCATION")
endif()
endif()
get_filename_component(GTEST_LIB_DIR "${GTEST_LIB_PATH}" DIRECTORY)
set_target_properties(GTest::gmock PROPERTIES
IMPORTED_LOCATION "${GTEST_LIB_DIR}/libgmock.a")
endif()

enable_testing()
include(CTest)
endif()
Expand Down Expand Up @@ -626,7 +645,7 @@ if(GTEST_FOUND)
add_executable(cpptest ${TEST_SRCS})
# include runtime files for unit testing
target_include_directories(cpptest PUBLIC "src/runtime")
target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main pthread dl)
target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GTest::gmock pthread dl)
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1)
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
# For some reason, compile definitions are not propagated correctly, so we manually add them here
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class MetadataNode : public MetadataBaseNode {
public:
explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {}
static constexpr const char* _type_key = "metadata.MetadataNode";
const char* get_c_struct_name() const override;
inline int64_t version() const { return int64_t(data_->version); }
inline int64_t num_inputs() const { return data_->num_inputs; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
Expand All @@ -141,6 +142,7 @@ class TensorInfoNode : public MetadataBaseNode {
public:
explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {}
static constexpr const char* _type_key = "metadata.TensorInfoNode";
const char* get_c_struct_name() const override;
inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); }
inline int64_t num_shape() const { return data_->num_shape; }
inline ::tvm::support::Span<const int64_t, int64_t> shape() const {
Expand Down
31 changes: 25 additions & 6 deletions include/tvm/runtime/metadata_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ namespace metadata {
*/
class MetadataBaseNode : public ::tvm::runtime::Object {
public:
virtual const char* get_c_struct_name() const = 0;

static constexpr const char* _type_key = "metadata.MetadataBaseNode";
TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
};
Expand Down Expand Up @@ -157,7 +159,7 @@ class ArrayAccessor<const char*, ::tvm::runtime::String> {
*
* These are separate from TIR DataType because TIR does not model structs.
*/
enum MetadataTypeIndex : uint8_t {
enum MetadataKind : uint8_t {
kUint64 = 0,
kInt64 = 1,
kBool = 2,
Expand All @@ -173,20 +175,37 @@ enum MetadataTypeIndex : uint8_t {
*/
class MetadataArrayNode : public MetadataBaseNode {
public:
MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name)
: array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {}
MetadataArrayNode(Array<ObjectRef> array, MetadataKind kind, const char* type_key)
: array(::std::move(array)), kind{kind}, type_key{type_key} {}

const char* get_c_struct_name() const final;

std::string get_element_c_struct_name() const {
CHECK(kind == MetadataKind::kMetadata)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add an ASSERT_THROWS test for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

<< "cannot get struct name for MetadataArray with kind=" << kind;
constexpr int prefix_size = sizeof("metadata.") - 1;
constexpr int suffix_size = sizeof("Node") - 1;
std::string type_key_str(type_key);
return std::string("TVM") +
type_key_str.substr(prefix_size, type_key_str.size() - prefix_size - suffix_size);
}

Array<ObjectRef> array;
MetadataTypeIndex type_index;
const char* struct_name;

/*! \brief Describes the storage class of the emitted struct member. */
MetadataKind kind;

/*! \brief When `kind` is Metadata, type_key of the MetadataBaseNode used with this array. */
const char* type_key;

static constexpr const char* _type_key = "metadata.MetadataArrayNode";
TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
};

/*! \brief Reference class for MetadataArray. */
class MetadataArray : public MetadataBase {
public:
MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name);
MetadataArray(Array<ObjectRef> array, MetadataKind kind, const char* struct_name);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
};
Expand Down
21 changes: 20 additions & 1 deletion python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,8 @@ class PreflattenedBufferMap(SpecialStmt):
Example
-------
.. code-block:: python
T.preflattened_buffer_map({})
A0 = T.match_buffer(A, (48,), dtype="float32")
T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
"""

def __init__(self):
Expand All @@ -892,12 +893,30 @@ def preflattened_buffer(
for key, value in self.context.func_buffer_map.items():
if value.same_as(postflattened):
param = key
break

assert (
param is not None
), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."

if data is None:
data = self.context.func_buffer_map[param].data

buffer_name: str = f"{postflattened.name}_preflatten"
if align != -1:
if isinstance(align, IntImm):
align = align.value
else:
assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"

if offset_factor != 0:
if isinstance(offset_factor, IntImm):
offset_factor = offset_factor.value
else:
assert isinstance(
offset_factor, int
), f"offset_factor: want int or IntImm, got {offset_factor!r}"

Comment on lines +902 to +919
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add some unit tests for these cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

preflattened = tvm.tir.decl_buffer(
shape,
dtype,
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/testing/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
# pylint: disable=invalid-name, import-outside-toplevel, unused-variable
"""Common utility functions in TVM tir"""
import inspect
import re
import tvm
from tvm.ir.diagnostics import override_renderer


CHECK_ERROR_RE = re.compile(r"^.*# check_error: (.+)$")


def check_error(func, rel_lineno):
"""check if TIR script throws error"""
# Override the default renderer to accumulate errors
Expand All @@ -46,3 +50,12 @@ def render(e):
assert (
d.span.line - 1 == rel_lineno
), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}"

error_line = source_code.split("\n")[rel_lineno]
m = CHECK_ERROR_RE.match(error_line)
if m:
expected_error_text = m.group(1)
errors = [e.message for e in errors]
assert (
expected_error_text in errors
), f'check_error expects "{expected_error_text} in str(errors): {errors}'
11 changes: 11 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
doc << Doc::Indent(
2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
}

if (op->preflattened_buffer_map.size() != 0) {
// print preflattened_buffer_map
std::vector<Doc> preflattened_buffer_map_doc;
for (auto& v : op->preflattened_buffer_map) {
preflattened_buffer_map_doc.push_back(Print(v.first) << ": " << Print(v.second));
}
doc << Doc::Indent(2, Doc::NewLine()
<< "preflattened_buffer_map = {"
<< PrintSep(preflattened_buffer_map_doc, Doc::Text(", ")) << "}");
}
doc << PrintBody(op->body);
return doc;
}
Expand Down
Loading