From 5506472ae87e80fc0b81036beac51d5e0cdac294 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 28 Sep 2021 10:15:00 -0700 Subject: [PATCH 01/20] [microTVM][Zephyr] Add MIMXRT1050 board support (#9068) * add target support * fix ci issue --- apps/microtvm/zephyr/template_project/boards.json | 6 ++++++ .../zephyr/template_project/microtvm_api_server.py | 12 +++++++++--- python/tvm/target/target.py | 2 +- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/apps/microtvm/zephyr/template_project/boards.json b/apps/microtvm/zephyr/template_project/boards.json index bdfa51109ff7..aabed3322150 100644 --- a/apps/microtvm/zephyr/template_project/boards.json +++ b/apps/microtvm/zephyr/template_project/boards.json @@ -1,4 +1,10 @@ { + "mimxrt1050_evk": { + "board": "mimxrt1050_evk", + "model": "imxrt10xx", + "is_qemu": false, + "fpu": true + }, "mps2_an521": { "board": "mps2_an521", "model": "mps2_an521", diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index ed275e610912..f700b5774c72 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -162,6 +162,7 @@ def _get_device_args(options): "nucleo_l4r5zi": {"idVendor": 0x0483, "idProduct": 0x374B}, "nucleo_f746zg": {"idVendor": 0x0483, "idProduct": 0x374B}, "stm32f746g_disco": {"idVendor": 0x0483, "idProduct": 0x374B}, + "mimxrt1050_evk": {"idVendor": 0x1366, "idProduct": 0x0105}, } @@ -545,6 +546,10 @@ def _find_openocd_serial_port(cls, options): return ports[0].device + @classmethod + def _find_jlink_serial_port(cls, options): + return cls._find_openocd_serial_port(options) + @classmethod def _find_serial_port(cls, options): flash_runner = _get_flash_runner() @@ -555,9 +560,10 @@ def _find_serial_port(cls, options): if flash_runner == "openocd": return cls._find_openocd_serial_port(options) - raise FlashRunnerNotSupported( - f"Don't know how to deduce serial port for flash runner {flash_runner}" - ) + if flash_runner == "jlink": + return cls._find_jlink_serial_port(options) + + raise RuntimeError(f"Don't know how to deduce serial port for flash runner {flash_runner}") def __init__(self, options): self._options = options diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 4ce888170134..4e5826f5b2a2 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -286,7 +286,7 @@ def intel_graphics(model="unknown", options=None): "atsamd51": ["-mcpu=cortex-m4"], "cxd5602gg": ["-mcpu=cortex-m4"], "esp32": [], - "imxrt1060": ["-mcpu=cortex-m7"], + "imxrt10xx": ["-mcpu=cortex-m7"], "mps2_an521": ["-mcpu=cortex-m33"], "nrf52840": ["-mcpu=cortex-m4"], "nrf5340dk": ["-mcpu=cortex-m33"], From 163322cc920ae916850cfb47366066cda634766e Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Tue, 28 Sep 2021 10:27:23 -0700 Subject: [PATCH 02/20] [Relay] Prepare for new plan_devices.cc (part II) (#9130) * Prepare for new plan_devices.cc (part II) These changes came from changing https://github.com/apache/tvm/pull/9038 to use tvm.parser.fromtext instead of manual AST construction. - Demote FunctionOnDeviceAttrs to just a pair of DictAttrs entries so that the parser will understand them on Function definitions. - Connect some special operators to their attributes so parsing understands them at call sites. - Don't silently ignore attributes during parsing. - Implement OptFunctionOnDevice so won't add device annotations for kUnknownDeviceType. - Allow the parser to be given an initial metadata map to support examples which need constants. - More DLOG -> VLOG conversions to reduce debug clutter. * [checkpoint] Keep existing ParseModule ffi to simplify rust bindings * [checkpoint] Address Christopher's comments. * [checkpoint] Andrew's comments from #9038 * [checkpoint] Jared's comments from #9038 * [checkpoint] Woops, forgot rename. --- include/tvm/ir/attrs.h | 25 +++++++ include/tvm/ir/function.h | 21 ++++++ include/tvm/parser/parser.h | 8 ++- include/tvm/relay/attrs/annotation.h | 41 ++++++++++-- include/tvm/relay/attrs/function.h | 66 ------------------- python/tvm/parser/__init__.py | 6 +- src/ir/diagnostic.cc | 6 +- src/parser/meta_ref.h | 3 +- src/parser/parser.cc | 52 ++++++++++----- src/parser/source_map.cc | 6 +- src/relay/op/annotation/annotation.cc | 50 +++++++------- src/relay/op/annotation/annotation.h | 23 +++++-- src/relay/op/memory/device_copy.cc | 1 + src/relay/op/memory/device_copy.h | 2 +- src/relay/op/memory/memory.cc | 2 + src/relay/op/vm/vm.cc | 3 + src/relay/transforms/type_infer.cc | 1 - .../relay/op/annotation/test_annotation.py | 11 ++-- tests/python/relay/test_ir_parser.py | 25 ++++++- 19 files changed, 211 insertions(+), 141 deletions(-) delete mode 100644 include/tvm/relay/attrs/function.h diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index fa1861051e2f..715c96eb6ea5 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -357,6 +357,31 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v return input; } +/*! + * \brief Copy the function or module, but overrides the attributes with the entries from \p attrs. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attrs Key/values attributes to add to \p input. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with updated attributes. + */ +template +inline TFunc WithAttrs(TFunc input, Map attrs) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = input.CopyOnWrite(); + if (node->attrs.defined()) { + for (const auto& pair : attrs) { + node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); + } + } else { + node->attrs = DictAttrs(std::move(attrs)); + } + return input; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 13b984d9cb35..5ee719f9964f 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -189,6 +189,27 @@ constexpr const char* kTarget = "target"; * Type: String */ constexpr const char* kGlobalSymbol = "global_symbol"; + +/*! + * \brief The device type which will hold each of the functions parameters. + * + * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but + * may be included as an annotation on user programs. + * + * Type: Array (but interpreted as Array) + */ +constexpr const char* kParamDeviceTypes = "param_device_types"; + +/*! + * \brief The device type which will hold the function result. + * + * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but + * may be included as an annotation on user programs. + * + * Type: Integer (but interpreted as DLDeviceType) + */ +constexpr const char* kResultDeviceType = "result_device_type"; + } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/parser/parser.h b/include/tvm/parser/parser.h index 7673eec2a337..8c2722050905 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/parser/parser.h @@ -23,6 +23,7 @@ * \file parser.h * \brief A parser for TVM IR. */ +#include #include #include @@ -32,8 +33,11 @@ namespace tvm { namespace parser { -IRModule ParseModule(std::string file_name, std::string file_content, - Optional init_module = Optional()); +using MetaTable = Map>; + +IRModule ParseModule(const std::string& file_name, const std::string& file_content, + const Optional& init_module = Optional(), + const MetaTable& init_meta_table = MetaTable()); } // namespace parser } // namespace tvm diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index bc55965ee852..85ac3f36ff60 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -32,17 +32,44 @@ namespace tvm { namespace relay { /*! - * \brief Attributes for the "on_device" operator. + * \brief Attributes for the "on_device" special operator. * - * The relay call + * The Relay call (aka 'annotation'): * \code - * on_device(expr, device_type=2) + * on_device(sub_expr, device_type=2) * \endcode - * denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2 - * (i.e. \p kDLCuda). Semantically the operator is the identity function. + * constrains \p sub_expr to execute and store its result on a device with \p DLDeviceType \p 2 + * (i.e. a \p kDLCuda device). However the annotation itself may appear in an expression to be + * executed and stored on a different device. If so the compiler will automatically insert a + * "device_copy" call to mediate the transition between devices. * - * See also FunctionOnDeviceAttrs in include/relay/attrs/function.h for the function-level - * companion. + * E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then: + * \code + * multiply(on_device(add(%x, %y), device_type=2), %z) + * \endcode + * indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU. + * The compiler will rewrite this to: + * \code + * multiply(device_copy(add(%x, %y), src_dev_type=2, dst_dev_type=1), %z) + * \endcode + * + * The Relay call + * \code + * on_device(sub_expr, device_type=2, is_fixed=True) + * \endcode + * is similar to the above, however the annotation itself must appear in an expression on the + * same device. The compiler will check the devices are consistent, and will not insert any + * "device_copy" call. This form of annotation shouldn't be necessary in user programs. However + * it is needed by the \p PlanDevices pass to fully specify the results of device planning so that + * the pass is idempotent. + * + * E.g.: The following program is equivalent to the above: + * \code + * let %a = on_device(add(%x, %y), device_type=2, is_fixed=True) + * multiply(device_copy(%a, src_dev_type=2, dst_dev_type=1), %z) + * \endcode + * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored + * on the GPU. */ struct OnDeviceAttrs : public tvm::AttrsNode { // TODO(mbs): Replace device types with TargetDevice. diff --git a/include/tvm/relay/attrs/function.h b/include/tvm/relay/attrs/function.h deleted file mode 100644 index f4f94131da1f..000000000000 --- a/include/tvm/relay/attrs/function.h +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relay/attrs/function.h - * \brief Attributes for Relay Functions which don't make sense on PrimFuncs. - */ -#ifndef TVM_RELAY_ATTRS_FUNCTION_H_ -#define TVM_RELAY_ATTRS_FUNCTION_H_ - -namespace tvm { -namespace relay { -/*! - * \brief Attributes for Relay function definitions which capture the devices for the - * function parameters and result. - * - * See also OnDeviceAttrs in include/tvm/relay/attrs/annotation.h for the companion "on_device" - * call attributes. - */ -struct FunctionOnDeviceAttrs : public tvm::AttrsNode { - /*! \brief Device type on which each of the function's arguments already resides. */ - Array param_device_types; - // TODO(mbs): Replace device types with TargetDevice. - /*! \brief Device type on which function body should be evaluated. */ - int result_device_type = kInvalidDeviceType; - - TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") { - TVM_ATTR_FIELD(param_device_types) - .describe("The type of the virtual device which holds each function parameters."); - TVM_ATTR_FIELD(result_device_type) - .describe("The type of the virtual device which will hold the function's result.") - .set_default(0); - } -}; - -namespace attr { - -/*! - * \brief Device annotations for function parameters and results. - * - * Type: FunctionOnDeviceAttrs - */ -constexpr static const char* kFunctionAttrsKey = "on_device"; - -} // namespace attr - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_ATTRS_FUNCTION_H_ diff --git a/python/tvm/parser/__init__.py b/python/tvm/parser/__init__.py index 60fcddb17f08..d75ad16ebab2 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/parser/__init__.py @@ -26,8 +26,10 @@ def add(self, name, content): return _ffi.get_global_func("SourceMapAdd")(self, name, content) -def parse(source, source_name="from_string"): - return _ffi_api.ParseModule(source_name, source) +def parse(source, source_name="from_string", init_module=None, init_meta_table=None): + if init_meta_table is None: + init_meta_table = {} + return _ffi_api.ParseModuleInContext(source_name, source, init_module, init_meta_table) def parse_expr(source): diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 876113b85f6e..b9677d198eba 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -242,10 +242,10 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s } auto source = (*it).second; - DLOG(INFO) << "Source: " << std::endl << source->source; + VLOG(1) << "Source: " << std::endl << source->source; - DLOG(INFO) << "ReportAt " - << "span = " << span << " msg = " << diagnostic->message; + VLOG(1) << "ReportAt " + << "span = " << span << " msg = " << diagnostic->message; auto line_text = source.GetLine(span->line); diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index 481f334cb0fe..483b7f726e07 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -26,6 +26,7 @@ #define TVM_PARSER_META_REF_H_ #include +#include #include #include @@ -36,8 +37,6 @@ namespace parser { using namespace relay; -using MetaTable = Map>; - /*! * \brief Options for allocating storage. */ diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 93dc687d72f5..5eec716cc20c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1092,8 +1092,6 @@ class Parser { Array generics; if (Peek()->token_type == TokenType::kLSquare) { - // If we have generics we need to add a type scope. - PushTypeScope(); generics = ParseSequence( TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { auto type_var_name = Match(TokenType::kIdentifier).ToString(); @@ -1444,6 +1442,10 @@ class Parser { ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } + } else { + this->diag_ctx.EmitFatal(Diagnostic::Error(op->span) + << "unable to determine the 'attrs_type_key' with which " + "to represent the call attributes for this operator"); } } return true; @@ -1867,7 +1869,7 @@ class Parser { }; Parser InitParser(const std::string& file_name, const std::string& file_content, - Optional init_module) { + const Optional& init_module, const MetaTable& init_meta_table) { VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size(); SourceName src_name = SourceName::Get(file_name); Source source(src_name, file_content); @@ -1886,19 +1888,33 @@ Parser InitParser(const std::string& file_name, const std::string& file_content, auto tokens_and_table = Tokenize(diag_ctx, source); auto tokens = tokens_and_table.first; - auto meta_data_table = tokens_and_table.second; + MetaTable meta_data_table = tokens_and_table.second.ToMetadata(); + + // Merge any entries in init_meta_table into anything captured in the #[metadata] section + // of the file_content. Metadata references within file_content must use indexes which account + // for this ordering. + for (const auto& pair : init_meta_table) { + Array items; + if (meta_data_table.count(pair.first)) { + items = meta_data_table[pair.first]; + } + for (const auto& obj : pair.second) { + items.push_back(obj); + } + meta_data_table.Set(pair.first, items); + } - return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), meta_data_table.ToMetadata()); + return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table)); } -IRModule ParseModule(std::string file_name, std::string file_content, - Optional init_module) { +IRModule ParseModule(const std::string& file_name, const std::string& file_content, + const Optional& init_module, const MetaTable& init_meta_table) { VLOG(0) << "ParseModule"; - auto parser = InitParser(file_name, file_content, init_module); + auto parser = InitParser(file_name, file_content, init_module, init_meta_table); auto mod = parser.ParseModule(); ICHECK(mod.defined()) << "The parser must return a non-null module."; - // NB(@jroesch): it is very important that we render any errors before we procede - // if there were any errors which allow the parser to procede we must render them + // NB(@jroesch): it is very important that we render any errors before we proceed + // if there were any errors which allow the parser to proceed we must render them // here. parser.diag_ctx.Render(); auto infer_type = tvm::relay::transform::InferType(); @@ -1906,22 +1922,28 @@ IRModule ParseModule(std::string file_name, std::string file_content, return infer_type(mod); } -Expr ParseExpr(std::string file_name, std::string file_content) { +Expr ParseExpr(const std::string& file_name, const std::string& file_content) { VLOG(0) << "ParseExpr"; - auto parser = InitParser(file_name, file_content, Optional()); + auto parser = InitParser(file_name, file_content, Optional(), MetaTable()); parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); parser.Match(TokenType::kEndOfFile); - // NB(@jroesch): it is very important that we render any errors before we procede - // if there were any errors which allow the parser to procede we must render them + // NB(@jroesch): it is very important that we render any errors before we proceed + // if there were any errors which allow the parser to proceed we must render them // here. parser.diag_ctx.Render(); return expr; } +TVM_REGISTER_GLOBAL("parser.ParseModuleInContext") + .set_body_typed([](const std::string& file_name, const std::string& file_content, + const Optional& init_module, const MetaTable& init_meta_table) { + return ParseModule(file_name, file_content, init_module, init_meta_table); + }); + TVM_REGISTER_GLOBAL("parser.ParseModule") - .set_body_typed([](tvm::String file_name, tvm::String file_content) { + .set_body_typed([](const std::string& file_name, const std::string& file_content) { return ParseModule(file_name, file_content); }); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 4e79d0e74c59..3c1329670c40 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -60,7 +60,7 @@ Source::Source(SourceName src_name, std::string source) { } tvm::String Source::GetLine(int line) { - DLOG(INFO) << "Source::GetLine: line=" << line; + VLOG(1) << "Source::GetLine: line=" << line; ICHECK(line - 1 < static_cast((*this)->line_map.size())) << "requested line: " << line << "at index: " << (line - 1) << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; @@ -69,10 +69,10 @@ tvm::String Source::GetLine(int line) { auto range = (*this)->line_map.at(line - 1); int line_start = range.first; int line_length = range.second; - DLOG(INFO) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; + VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; // TODO(@jroesch): expose substring on tvm::String. auto line_text = std::string((*this)->source).substr(line_start, line_length); - DLOG(INFO) << "Source::GetLine: line_text=" << line_text; + VLOG(1) << "Source::GetLine: line_text=" << line_text; return line_text; } diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 4eda15937f3a..284f8b88ee0d 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -26,7 +26,6 @@ #include "./annotation.h" #include -#include #include #include #include @@ -54,7 +53,7 @@ Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span); } -Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { +Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { if (device_type == kInvalidDeviceType) { // Undefined signals no annotation is required. return expr; @@ -92,6 +91,7 @@ RELAY_REGISTER_OP("on_device") .add_argument("data", "Tensor", "The input data.") .set_support_level(10) .add_type_rel("Identity", IdentityRel) + .set_attrs_type_key("relay.attrs.OnDeviceAttrs") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) @@ -128,14 +128,10 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { return {}; } -TVM_REGISTER_NODE_TYPE(FunctionOnDeviceAttrs); - Function FunctionOnDevice(Function function, Array param_device_types, - DLDeviceType result_device_type) { - auto attrs = make_object(); - attrs->param_device_types = std::move(param_device_types); - attrs->result_device_type = result_device_type; - return WithAttr(std::move(function), attr::kFunctionAttrsKey, Attrs(std::move(attrs))); + Integer result_device_type) { + return WithAttrs(std::move(function), {{tvm::attr::kParamDeviceTypes, param_device_types}, + {tvm::attr::kResultDeviceType, result_device_type}}); } Function FunctionOnDevice(Function function, const std::vector& param_device_types, @@ -143,9 +139,21 @@ Function FunctionOnDevice(Function function, const std::vector& pa Array arr; arr.reserve(param_device_types.size()); for (const auto device_type : param_device_types) { - arr.push_back(static_cast(device_type)); + arr.push_back(static_cast(device_type)); + } + return FunctionOnDevice(std::move(function), std::move(arr), + static_cast(result_device_type)); +} + +Function MaybeFunctionOnDevice(Function function, + const std::vector& param_device_types, + DLDeviceType result_device_type) { + if (std::all_of(param_device_types.begin(), param_device_types.end(), + [](DLDeviceType type) { return type == kInvalidDeviceType; }) && + result_device_type == kInvalidDeviceType) { + return function; } - return FunctionOnDevice(function, arr, result_device_type); + return FunctionOnDevice(function, param_device_types, result_device_type); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") @@ -156,32 +164,26 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") }); DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node) { - auto opt_attrs = function_node->GetAttr(attr::kFunctionAttrsKey); - if (!opt_attrs) { + auto opt_integer = function_node->GetAttr(tvm::attr::kResultDeviceType); + if (!opt_integer) { // No annotation. return kInvalidDeviceType; } - const auto* opt_function_on_device_attrs = opt_attrs.value().as(); - ICHECK(opt_function_on_device_attrs != nullptr) - << "function '" << attr::kFunctionAttrsKey << "' annotation must be a FunctionOnDeviceAttrs"; - return static_cast(opt_function_on_device_attrs->result_device_type); + return static_cast(opt_integer.value()->value); } DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i) { ICHECK_LT(i, function_node->params.size()) << "param index " << i << " out of range for function of arity " << function_node->params.size(); - auto opt_attrs = function_node->GetAttr(attr::kFunctionAttrsKey); - if (!opt_attrs) { + auto opt_array = function_node->GetAttr>(tvm::attr::kParamDeviceTypes); + if (!opt_array) { // No annotation. return kInvalidDeviceType; } - const auto* opt_function_on_device_attrs = opt_attrs.value().as(); - ICHECK(opt_function_on_device_attrs != nullptr) - << "function '" << attr::kFunctionAttrsKey << "' annotation must be a FunctionOnDeviceAttrs"; - ICHECK_EQ(opt_function_on_device_attrs->param_device_types.size(), function_node->params.size()) + ICHECK_EQ(opt_array.value().size(), function_node->params.size()) << "annotation parameters do not match function arity"; - return static_cast(opt_function_on_device_attrs->param_device_types[i]->value); + return static_cast(opt_array.value()[i]->value); } Expr StopFusion(Expr data) { diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index e3a4aea4708c..643a82116b5b 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -39,6 +39,8 @@ const Op& OnDeviceOp(); /*! * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. + * + * See \p OnDeviceAttrs for an overview. */ Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); @@ -52,7 +54,7 @@ Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); * - \p expr is a constructor. There should probably be device polymorphic but are in an * in-between state at the moment. */ -Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); +Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); /*! \brief Result of \p GetOnDeviceProps. */ struct OnDeviceProps { @@ -83,24 +85,31 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr); inline bool IsOnDeviceCall(const Expr& expr) { return GetOnDeviceProps(expr).body.defined(); } /*! - * \brief Returns \p function annotated with "on_device" attributes capturing parameter and result - * devices types. However returns \p function directly if all device types are \p - * kInvalidDeviceType. + * \brief Returns \p function annotated with "param_device_types" and "result_device_type" + * attributes capturing parameter and result devices types respectively. */ Function FunctionOnDevice(Function function, Array param_device_types, - DLDeviceType body_device_type); + Integer body_device_type); Function FunctionOnDevice(Function function, const std::vector& param_device_types, DLDeviceType body_device_type); +/*! + * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and + * result device types are \p kInvalidDeviceType. + */ +Function MaybeFunctionOnDevice(Function function, + const std::vector& param_device_types, + DLDeviceType result_device_type); + /*! * \brief Returns the device type for the resut of \p function_node, or \p kInvalidDeviceType - * if function does not have "on_device" annotation. + * if function does not have "result_device_type" annotation. */ DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node); /*! * \brief Returns the device type for the \p i'th parameter of \p function_node, or - * \p kInvalidDeviceType if function does not have "on_device" annotation. + * \p kInvalidDeviceType if function does not have "param_device_types" annotation. */ DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i); diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index b94caac2c3d9..dce89aa91b65 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -76,6 +76,7 @@ on different devices. .add_argument("data", "Tensor", "The input data.") .set_support_level(10) .add_type_rel("Identity", IdentityRel) + .set_attrs_type_key("relay.attrs.DeviceCopyAttrs") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h index d590d8510f17..d21fdb6abe19 100644 --- a/src/relay/op/memory/device_copy.h +++ b/src/relay/op/memory/device_copy.h @@ -45,7 +45,7 @@ Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. * However, return \p expr directly if \p src_dev_type equals \p dst_dev_type. */ -Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); +Expr MaybeDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); /*! \brief Result of \p GetDeviceCopyProps. */ struct DeviceCopyProps { diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 68a83ebba1fe..5339d48e3a2f 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -86,6 +86,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") .add_argument("size", "Tensor", "The size of the storage to allocate.") .add_argument("alignment", "Tensor", "The alignment of the storage.") .add_type_rel("AllocStorage", AllocStorageRel) + .set_attrs_type_key("relay.attrs.AllocStorageAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) @@ -200,6 +201,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor") .add_argument("offset", "Tensor", "The offset into the backing storage.") .add_argument("shape", "Tensor", "The shape of the tensor to allocate.") .add_type_rel("AllocTensor", AllocTensorRel) + .set_attrs_type_key("relay.attrs.AllocTensorAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc index a74a259a114f..be31b5482937 100644 --- a/src/relay/op/vm/vm.cc +++ b/src/relay/op/vm/vm.cc @@ -50,6 +50,7 @@ RELAY_REGISTER_OP("vm.shape_of") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The input tensor") .add_type_rel("ShapeOf", ShapeOfRel) + .set_attrs_type_key("relay.attrs.ShapeOfAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) @@ -131,6 +132,7 @@ RELAY_REGISTER_OP("vm.shape_func") .add_argument("func", "Function", "The operation to call") .add_argument("ins", "Tuple", "The input tensors.") .add_argument("outs", "Tuple", "The output tensors.") + .set_attrs_type_key("relay.attrs.ShapeFuncAttrs") .add_type_rel("ShapeFuncRel", ShapeFuncRel) .set_support_level(10) .set_attr("TOpPattern", kOpaque) @@ -214,6 +216,7 @@ RELAY_REGISTER_OP("vm.reshape_tensor") .add_argument("data", "Tensor", "The input tensor") .add_argument("shape", "Tensor", "The output shape tensor") .add_type_rel("ReshapeTensor", ReshapeTensorRel) + .set_attrs_type_key("relay.attrs.ReshapeTensorAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 6c2371716b16..ebdf1fed2fab 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -824,7 +824,6 @@ Pass InferType() { auto pass_info = PassInfo(0, "InferType", {}); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { - DLOG(INFO) << "tvm::relay::transform::InferType"; // Execute the pass function and return a new module. IRModule updated_mod = mod->ShallowCopy(); diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 51daa9aaa06a..58e559eb9680 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -54,13 +54,10 @@ def test_function_on_device(): f = relay.Function([x, y], relay.add(x, y)) func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") assert isinstance(func, relay.Function) - assert len(func.attrs["on_device"].param_device_types) == 2 - assert func.attrs["on_device"].param_device_types[0] == 1 - # ie kDLCPU - assert func.attrs["on_device"].param_device_types[1] == 2 - # ie kDLCUDA - assert func.attrs["on_device"].result_device_type == 2 - # ie KDLCUDA + assert len(func.attrs["param_device_types"]) == 2 + assert func.attrs["param_device_types"][0] == 1 # ie kDLCPU + assert func.attrs["param_device_types"][1] == 2 # ie kDLCUDA + assert func.attrs["result_device_type"] == 2 # ie KDLCUDA if __name__ == "__main__": diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 099e127aeba9..fdbd3924ffb7 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -23,7 +23,6 @@ from numpy import isclose from typing import Union - SEMVER = '#[version = "0.0.5"]\n' BINARY_OPS = { @@ -967,6 +966,30 @@ def test_func_attrs(): assert_parses_as(func.astext(), func) +def test_init_module_and_metatable(): + init_metatable = {"relay.Constant": [relay.const(np.random.rand(2, 3), dtype="float32")]} + init_module = tvm.parser.fromtext( + SEMVER + + """ + def @f(%y : Tensor[(2, 3), float32]) -> Tensor[(2, 3), float32] { + negative(%y) + } + """, + ) + mod = tvm.parser.parse( + SEMVER + + """ + def @main(%x: Tensor[(2, 3), float32]) { + add(@f(%x), meta[relay.Constant][0]) + } + """, + "from_string", + init_module, + init_metatable, + ) + roundtrip(mod) + + if __name__ == "__main__": import sys From 4905a8c4c8b58b90fe4226e32724a1215f2b20b8 Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 29 Sep 2021 04:35:08 +0900 Subject: [PATCH 03/20] [Torch] Support returning quantized weights and bias for BYOC use cases (#9135) * [Torch] refactored the way is bias quantization done * support returning 8bit weight * add test * add doc * pylint * return_int8_weight -> keep_quantized_weight * fixed for dynamic linear case * remove test function call * simplifying --- python/tvm/relay/frontend/pytorch.py | 25 ++++- python/tvm/relay/frontend/qnn_torch.py | 115 ++++++++++++++++------ tests/python/frontend/pytorch/qnn_test.py | 43 +++++++- 3 files changed, 150 insertions(+), 33 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 39bcfc68e421..56df39fdaa30 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3713,6 +3713,7 @@ def from_pytorch( custom_convert_map=None, default_dtype="float32", use_parser_friendly_name=False, + keep_quantized_weight=False, ): """Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. @@ -3745,6 +3746,16 @@ def from_pytorch( so a variable name like "dense.weight" cannot be parsed correctly. Use this option when you want to run the AnnotateSpans pass on the imported module. + keep_quantized_weight : bool + Return quantized weights and bias, rather than float ones. PyTorch stores quantized weights + in a custom format, so we cannot directly access 8 bit weights as Numpy arrays. We use + a PyTorch function to unpack quantized weights into float32 arrays and quantization + parameters. By default, we return float32 weights and rely on the QNN lowering and the + Relay constant folding pass to quantize weights at compile time. In BYOC use cases, however, + we cannot apply the constant folding pass on a QNN graph. If keep_quantized_weight is True, + we quantize weights in the frontend using a function that is equivalent to + qnn.op.quantize(...) operating on Numpy arrays. + Returns ------- mod : tvm.IRModule @@ -3789,9 +3800,17 @@ def from_pytorch( # For quantized models quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"]) if len(quantized_ops.intersection(set(op_names))) > 0: - weight_quant_params = qnn_torch.get_weight_quant_params(script_module) - qnn_torch.add_input_quant_params_to_op_inputs(graph) - qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params) + weight_quant_params = qnn_torch.get_weight_quant_params( + script_module, packed_param_map.values() + ) + input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph) + qnn_torch.add_quant_params_to_outputs( + outputs, + packed_param_map, + weight_quant_params, + input_scales_for_bias, + keep_quantized_weight, + ) qnn_torch.add_quant_params(tvm_params, weight_quant_params) converter.update_convert_map(qnn_torch.convert_map) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 9eafae905baf..af3c352d15ae 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -32,16 +32,12 @@ class QNNParam: """A placeholder for weight quantization parameters""" - def __init__(self, weight, bias, scale, zero_point, param_key): - param_prefix = param_key[: -len("._packed_params")] - self.weight_var = _expr.var(param_prefix + "_weight", shape=weight.shape) + def __init__(self, weight, bias, scale, zero_point): self.weight = weight if bias is not None: - self.bias_var = _expr.var(param_prefix + "_bias", shape=bias.shape) self.bias = bias.detach().numpy() else: - self.bias_var = None self.bias = None self.scale = _expr.const(scale) @@ -55,10 +51,8 @@ class ConvPackedParam(QNNParam): together with weights and quantization parameters """ - def __init__( - self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups - ): - super().__init__(weight_np, bias, scale, zero_point, param_name) + def __init__(self, weight_np, bias, scale, zero_point, stride, padding, dilation, groups): + super().__init__(weight_np, bias, scale, zero_point) self.stride = stride self.padding = padding self.dilation = dilation @@ -81,23 +75,21 @@ def _get_quant_params(qweight): return weight_np, scales, 0 -def make_qnn_param(param_name, qweight, bias): +def make_qnn_param(qweight, bias): weight_np, scale, zero_point = _get_quant_params(qweight) - return QNNParam(weight_np, bias, scale, zero_point, param_name) + return QNNParam(weight_np, bias, scale, zero_point) -def make_conv_packed_param(param_name, qweight, bias, packed_params): +def make_conv_packed_param(qweight, bias, packed_params): weight_np, scale, zero_point = _get_quant_params(qweight) stride = packed_params.stride() padding = packed_params.padding() dilation = packed_params.dilation() groups = packed_params.groups() - return ConvPackedParam( - weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups - ) + return ConvPackedParam(weight_np, bias, scale, zero_point, stride, padding, dilation, groups) -def get_weight_quant_params(script_module): +def get_weight_quant_params(script_module, packed_param_names): """Retrive and unpack weight parameters from quantized modules""" import torch @@ -114,6 +106,9 @@ def filter_func(named_module): key = name + "." + param_name state_dict = m.state_dict() + if key not in packed_param_names: + continue + if len(state_dict) == 0 and not hasattr(m, param_name): # for v1.6 and above # This case seems to happen if a model is serialized @@ -130,28 +125,87 @@ def filter_func(named_module): if "Conv" in m.original_name and len(state_dict) == 0: qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params) - quant_params[key] = make_conv_packed_param(key, qweight, bias, packed_params) + quant_params[key] = make_conv_packed_param(qweight, bias, packed_params) elif "Conv" in m.original_name: qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params) - quant_params[key] = make_qnn_param(key, qweight, bias) + quant_params[key] = make_qnn_param(qweight, bias) elif m.original_name == "LinearPackedParams": qweight, bias = torch.ops.quantized.linear_unpack(packed_params) - quant_params[key] = make_qnn_param(key, qweight, bias) + quant_params[key] = make_qnn_param(qweight, bias) return quant_params -def add_quant_params_to_outputs(outputs, packed_param_map, quant_params): +def quantize_numpy(weight, scale, zero_point, out_dtype_np): + iinfo = np.iinfo(out_dtype_np) + clip_min = iinfo.min + clip_max = iinfo.max + if len(scale.shape) > 0: + scale = np.reshape(scale, [weight.shape[0]] + [1] * (len(weight.shape) - 1)) + transformed = zero_point + weight / scale + return np.clip(np.round(transformed), clip_min, clip_max).astype(out_dtype_np) + + +def add_quant_params_to_outputs( + outputs, packed_param_map, quant_params, input_scales_for_bias, keep_quantized_weight=False +): """ Add quant params to outputs so that they can be referenced by other ops later. Weights are quantized here. """ for node_name, packed_param_name in packed_param_map.items(): qparam = quant_params[packed_param_name] - qweight = relay.qnn.op.quantize( - qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0 - ) - params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var] + weight_scale = _get_numpy(qparam.scale) + param_prefix = packed_param_name[: -len("._packed_params")] + + if keep_quantized_weight: + qparam.weight_var = _expr.var( + param_prefix + "_weight", shape=qparam.weight.shape, dtype="int8" + ) + qparam.weight = quantize_numpy( + qparam.weight, weight_scale, _get_numpy(qparam.zero_point), np.int8 + ) + qweight = qparam.weight_var + else: + qparam.weight_var = _expr.var( + param_prefix + "_weight", shape=qparam.weight.shape, dtype="float32" + ) + qweight = relay.qnn.op.quantize( + qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0 + ) + + if qparam.bias is not None: + float_bias_var = _expr.var( + param_prefix + "_bias", shape=qparam.bias.shape, dtype="float32" + ) + if node_name not in input_scales_for_bias: + # This case is for dynamic quantization, where the input activation scale is + # unknown until runtime. + qparam.bias_var = float_bias_var + qbias = qparam.bias_var + elif keep_quantized_weight: + qparam.bias_var = _expr.var( + param_prefix + "_bias", shape=qparam.bias.shape, dtype="int32" + ) + qparam.bias = quantize_numpy( + qparam.bias, input_scales_for_bias[node_name] * weight_scale, 0, np.int32 + ) + qbias = qparam.bias_var + else: + qparam.bias_var = float_bias_var + qbias = relay.qnn.op.quantize( + qparam.bias_var, + _expr.const(input_scales_for_bias[node_name] * weight_scale), + _expr.const(0, "int32"), + out_dtype="int32", + axis=0, + ) + else: + qbias = None + + quant_params[packed_param_name] = qparam + + params = [qweight, qparam.scale, qparam.zero_point, qbias] if isinstance(quant_params[packed_param_name], ConvPackedParam): params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups] @@ -367,6 +421,8 @@ def add_input_quant_params_to_op_inputs(graph): need_input_quant_param = set(num_quantized_inputs.keys()) need_input_quant_param.add("quantized::cat") + input_scales_for_bias = {} + for node in graph.nodes(): operator = node.kind() if operator not in need_input_quant_param: @@ -401,6 +457,12 @@ def add_input_quant_params_to_op_inputs(graph): node.addInput(scale) node.addInput(zp) + if "conv2d" in operator or "linear" in operator: + # This is required for quantizing the bias + input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value") + + return input_scales_for_bias + def add_quant_params(params, quant_params): """Add quant parameters to TVM param map""" @@ -478,10 +540,7 @@ def _do_bias_and_requantize( # Instead, the torch way requires rounding of activation at runtime if bias is not None: - qbias = relay.qnn.op.quantize( - bias, requant_input_scale, _expr.const(0, "int32"), out_dtype="int32", axis=0 - ) - requantize_input = _op.nn.bias_add(output, qbias) + requantize_input = _op.nn.bias_add(output, bias) else: requantize_input = output diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 704245040025..65e5692dc4fb 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -40,9 +40,15 @@ def torch_version_check(): return version.parse(torch.__version__) > version.parse("1.4.0") -def get_tvm_runtime(script_module, input_name, ishape): +def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False): input_shapes = [(input_name, ishape)] - mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + mod, params = relay.frontend.from_pytorch( + script_module, input_shapes, keep_quantized_weight=keep_quantized_weight + ) + + if keep_quantized_weight: + for p in params.values(): + assert p.dtype in ["int8", "int32"] with tvm.transform.PassContext(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda @@ -609,3 +615,36 @@ def test_qnn_mergecomposite(): input_name = "image" run_qnn_mergecomposite(script_module, input_name, inp.shape) + + +def test_keep_quantized_weight(): + qmodules = [] + + for per_channel in [False, True]: + qmodules += [ + ((1, 3, 224, 224), ConvBn(), per_channel), + ((16, 16), Linear(), per_channel), + ] + + for (ishape, raw_module, per_channel) in qmodules: + raw_module.eval() + inp = torch.rand(ishape) + + quantize_model(raw_module, inp, per_channel=per_channel) + script_module = torch.jit.trace(raw_module, inp).eval() + + input_name = "input" + + runtime = get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False) + runtime.set_input(input_name, inp.numpy().copy()) + runtime.run() + tvm_result = runtime.get_output(0).numpy() + + runtime_int8_weight = get_tvm_runtime( + script_module, input_name, ishape, keep_quantized_weight=True + ) + runtime_int8_weight.set_input(input_name, inp.numpy().copy()) + runtime_int8_weight.run() + tvm_result_int8_weight = runtime_int8_weight.get_output(0).numpy() + + tvm.testing.assert_allclose(tvm_result, tvm_result_int8_weight) From 7adbb27bada4e58612e9beac9035dcc223d67ded Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Tue, 28 Sep 2021 16:29:43 -0500 Subject: [PATCH 04/20] [UnitTests] Enable minimum testing on Vulkan target in CI (#9093) * [UnitTests] Enable minimum testing on Vulkan target in CI - Include the Vulkan runtime in the GPU build. - Run test_target_codegen_vulkan.py as part of the `python3: GPU` CI step. * [CI] Added a dummy task_config_build_gpu_vulkan.sh, to be removed later. The CI builds use the Jenkinsfile located in the ci-docker-staging branch, but the scripts in the PR that is being run. Temporarily adding back a task_config_build_gpu_vulkan.sh, which just calls the renamed task_config_build_gpu_other.sh. --- Jenkinsfile | 2 +- tests/scripts/task_config_build_gpu.sh | 1 + tests/scripts/task_config_build_gpu_other.sh | 35 +++++++++++++++++++ tests/scripts/task_config_build_gpu_vulkan.sh | 21 +++++------ .../task_python_integration_gpuonly.sh | 2 +- tests/scripts/task_python_unittest_gpuonly.sh | 18 ++++++++-- 6 files changed, 62 insertions(+), 17 deletions(-) create mode 100755 tests/scripts/task_config_build_gpu_other.sh diff --git a/Jenkinsfile b/Jenkinsfile index fa1629205080..b2852955323f 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -209,7 +209,7 @@ stage('Build') { make(ci_gpu, 'build', '-j2') pack_lib('gpu', tvm_multilib) // compiler test - sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu_vulkan.sh" + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu_other.sh" make(ci_gpu, 'build2', '-j2') } } diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 5f86476c64c7..3a429721709e 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -26,6 +26,7 @@ cp ../cmake/config.cmake . echo set\(USE_CUBLAS ON\) >> config.cmake echo set\(USE_CUDNN ON\) >> config.cmake echo set\(USE_CUDA ON\) >> config.cmake +echo set\(USE_VULKAN ON\) >> config.cmake echo set\(USE_OPENGL ON\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu_other.sh b/tests/scripts/task_config_build_gpu_other.sh new file mode 100755 index 000000000000..c11669a2ab0d --- /dev/null +++ b/tests/scripts/task_config_build_gpu_other.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is a compiler test to ensure that runtimes can compile +# correctly, even if they aren't actively tested in the CI. + +set -e +set -u + +mkdir -p build2 +cd build2 +cp ../cmake/config.cmake . + +echo set\(USE_OPENCL ON\) >> config.cmake +echo set\(USE_ROCM ON\) >> config.cmake +echo set\(USE_MICRO ON\) >> config.cmake +echo set\(USE_PROFILER ON\) >> config.cmake +echo set\(USE_LIBBACKTRACE OFF\) >> config.cmake +echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu_vulkan.sh b/tests/scripts/task_config_build_gpu_vulkan.sh index a5a26a1db0fb..93adc9667da7 100755 --- a/tests/scripts/task_config_build_gpu_vulkan.sh +++ b/tests/scripts/task_config_build_gpu_vulkan.sh @@ -16,18 +16,13 @@ # specific language governing permissions and limitations # under the License. -set -e -set -u +# TODO(Lunderberg): Remove this file once the Jenkinsfile in the +# ci-docker-staging branch no longer references it. -mkdir -p build2 -cd build2 -cp ../cmake/config.cmake . +# This file is a backwards compatibility file, as the TVM CI uses the +# Jenkinsfile from the ci-docker-staging branch, but the task scripts +# from the PR branch. -echo set\(USE_OPENCL ON\) >> config.cmake -echo set\(USE_ROCM ON\) >> config.cmake -echo set\(USE_VULKAN ON\) >> config.cmake -echo set\(USE_MICRO ON\) >> config.cmake -echo set\(USE_PROFILER ON\) >> config.cmake -echo set\(USE_LIBBACKTRACE OFF\) >> config.cmake -echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake -echo set\(USE_CCACHE OFF\) >> config.cmake +set -euo pipefail + +./tests/scripts/task_config_build_gpu_other.sh diff --git a/tests/scripts/task_python_integration_gpuonly.sh b/tests/scripts/task_python_integration_gpuonly.sh index ac09cb5a14a3..36c3883d4379 100755 --- a/tests/scripts/task_python_integration_gpuonly.sh +++ b/tests/scripts/task_python_integration_gpuonly.sh @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;vulkan;nvptx;opencl -device=mali,aocl_sw_emu" +export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;nvptx;opencl -device=mali,aocl_sw_emu" export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" export TVM_RELAY_TEST_TARGETS="cuda" export TVM_INTEGRATION_TESTSUITE_NAME=python-integration-gpu diff --git a/tests/scripts/task_python_unittest_gpuonly.sh b/tests/scripts/task_python_unittest_gpuonly.sh index 22f79bc70ec9..54dd085f1817 100755 --- a/tests/scripts/task_python_unittest_gpuonly.sh +++ b/tests/scripts/task_python_unittest_gpuonly.sh @@ -16,8 +16,22 @@ # specific language governing permissions and limitations # under the License. -export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;vulkan;nvptx;opencl -device=mali,aocl_sw_emu" -export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" +set -euo pipefail + +export PYTEST_ADDOPTS="-m gpu ${PYTEST_ADDOPTS:-}" + +# Test most of the enabled runtimes here. +export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;nvptx;opencl -device=mali,aocl_sw_emu" export TVM_UNITTEST_TESTSUITE_NAME=python-unittest-gpu ./tests/scripts/task_python_unittest.sh + +# Kept separate to avoid increasing time needed to run CI, testing +# only minimal functionality of Vulkan runtime. +export TVM_TEST_TARGETS="vulkan -from_device=0" +export TVM_UNITTEST_TESTSUITE_NAME=python-unittest-vulkan + +source tests/scripts/setup-pytest-env.sh + +run_pytest ctypes ${TVM_UNITTEST_TESTSUITE_NAME} tests/python/unittest/test_target_codegen_vulkan.py +run_pytest cython ${TVM_UNITTEST_TESTSUITE_NAME} tests/python/unittest/test_target_codegen_vulkan.py From f052abc957f3b33b7532ad01b8969c16b0ba76d5 Mon Sep 17 00:00:00 2001 From: anwang2009 Date: Tue, 28 Sep 2021 14:47:47 -0700 Subject: [PATCH 05/20] add nn.global_avgpool to fq2i (#9137) --- .../relay/transform/fake_quantization_to_integer.py | 11 +++++++++++ .../relay/test_pass_fake_quantization_to_integer.py | 13 +++++++++++++ 2 files changed, 24 insertions(+) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 0ed75191c40d..1adde9a4a430 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -101,6 +101,17 @@ def avgpool2d(expr, type_map): return [out, t] +@register_fake_quantization_to_integer("nn.global_avg_pool2d") +def global_avgpool2d(expr, type_map): + """Rewrite a global_avgpool op""" + arg = expr.args[0] + t = type_map[arg] + arg = relay.op.cast(arg, "int32") + out = relay.op.nn.global_avg_pool2d(arg) + out = relay.op.cast(out, t.dtype) + return [out, t] + + @register_fake_quantization_to_integer("nn.bias_add") def bias_add(expr, type_map): """Rewrite a bias_add op""" diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 3680310b4f92..c49d837ed920 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -268,6 +268,19 @@ def test_fake_quantize_avgpool(): compare_fq_to_int(op, [x_np], True) +def test_fake_quantize_global_avg_pool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.global_avg_pool2d(x) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np], True) + + def test_fake_quantize_reshape(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") From bffbbe76c2793e138078654f0f00fe6094073687 Mon Sep 17 00:00:00 2001 From: Arun Abraham Date: Wed, 29 Sep 2021 03:19:51 +0530 Subject: [PATCH 06/20] [Relay][ConvertLayout] Support for qnn.conv2d_transpose (#9139) --- python/tvm/relay/qnn/op/layout_conversions.py | 48 +++++++++++++ .../relay/test_pass_convert_op_layout.py | 68 +++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/python/tvm/relay/qnn/op/layout_conversions.py b/python/tvm/relay/qnn/op/layout_conversions.py index a7c90daf36a4..1a3b1771d6ce 100644 --- a/python/tvm/relay/qnn/op/layout_conversions.py +++ b/python/tvm/relay/qnn/op/layout_conversions.py @@ -78,3 +78,51 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts): return relay.qnn.op.conv2d(*inputs, **new_attrs) raise ValueError("Layout %s is not yet supported" % desired_data_layout) + + +@reg.register_convert_op_layout("qnn.conv2d_transpose") +def convert_qnn_conv2d_transpose(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for QNN conv2d_transpose op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data and kernel inputs respectively. + + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + # pylint: disable=import-outside-toplevel + from tvm import relay + + assert ( + len(desired_layouts) == 2 + ), "A desired layout is expected for both of qnn.conv2d_transpose's inputs" + desired_data_layout, desired_kernel_layout = map(str, desired_layouts) + assert desired_data_layout != "default", "Data layout cannot be default" + + new_attrs = dict(attrs) + new_attrs["data_layout"] = desired_data_layout + + if desired_kernel_layout != "default": + new_attrs["kernel_layout"] = desired_kernel_layout + return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs) + + # Handle default kernel layouts + if desired_data_layout == "NCHW": + new_attrs["kernel_layout"] = "OIHW" + return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs) + if desired_data_layout == "NHWC": + new_attrs["kernel_layout"] = "HWIO" + return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs) + + raise ValueError("Layout %s is not yet supported" % desired_data_layout) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index a1965aa2d0c5..9b4d154360b2 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1100,6 +1100,74 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_qnn_conv_transpose_requantize_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8") + weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8") + y = relay.qnn.op.conv2d_transpose( + x, + weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + y = relay.qnn.op.requantize( + y, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + out_dtype="int32", + ) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8") + weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8") + x = relay.layout_transform(x, "NHWC", "NCHW") + weight = relay.layout_transform(weight, "HWIO", "OIHW") + y = relay.qnn.op.conv2d_transpose( + x, + weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + out_dtype="int32", + ) + y = relay.qnn.op.requantize( + y, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + axis=1, + out_dtype="int32", + ) + y = relay.nn.relu(y) + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d_transpose": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_conv_convert_kernel_layout(): """Check that convolution kernel layout is correctly transformed.""" From c2bf39919d1be5184e17c99cede23a2079892eb3 Mon Sep 17 00:00:00 2001 From: sunway Date: Wed, 29 Sep 2021 08:28:11 +0800 Subject: [PATCH 07/20] [BYOC] support arbitrary input dims for add/mul/relu of dnnl c_src codegen (#9127) * support arbitrary input dims for add/mul/relu of dnnl c_src codegen * fix lint * fix Co-authored-by: sunway --- src/relay/backend/contrib/dnnl/codegen.cc | 41 +++++++++----- src/runtime/contrib/dnnl/dnnl.cc | 69 +++++++++++++++++------ src/runtime/contrib/dnnl/dnnl_kernel.h | 9 ++- 3 files changed, 85 insertions(+), 34 deletions(-) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index f0d360ae8b6d..ae58c2f08e8c 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -54,6 +54,15 @@ inline size_t GetShape1DSize(const Type& type) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } +inline std::string GetShapeString(std::vector shape) { + std::string v = "std::vector{"; + for (auto s : shape) { + v += std::to_string(s) + ","; + } + v += "}"; + return v; +} + std::vector Conv2d(const CallNode* call) { std::vector args; const auto* conv2d_attr = call->attrs.as(); @@ -98,12 +107,8 @@ std::vector Dense(const CallNode* call) { std::vector Relu(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); - // Args: N, C, H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } - + args.push_back(GetShapeString(ishape)); return args; } @@ -123,15 +128,25 @@ std::vector BatchNorm(const CallNode* call) { return args; } +// should comply with src/runtime/contrib/dnnl/dnnl.cc +#define DNNL_BINARY_ADD 0 +#define DNNL_BINARY_MUL 1 + std::vector Add(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); - + args.push_back(std::to_string(DNNL_BINARY_ADD)); // Args: H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } + args.push_back(GetShapeString(ishape)); + return args; +} +std::vector Multiply(const CallNode* call) { + std::vector args; + auto ishape = GetShape(call->args[0]->checked_type()); + args.push_back(std::to_string(DNNL_BINARY_MUL)); + // Args: H, W + args.push_back(GetShapeString(ishape)); return args; } @@ -239,11 +254,9 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C using ArgFunType = std::function(const CallNode*)>; static const std::map> op_map = { - {"nn.conv2d", {"dnnl_conv2d", Conv2d}}, - {"nn.dense", {"dnnl_dense", Dense}}, - {"nn.relu", {"dnnl_relu", Relu}}, - {"nn.batch_norm", {"dnnl_bn", BatchNorm}}, - {"add", {"dnnl_add", Add}}, + {"nn.conv2d", {"dnnl_conv2d", Conv2d}}, {"nn.dense", {"dnnl_dense", Dense}}, + {"nn.relu", {"dnnl_relu", Relu}}, {"nn.batch_norm", {"dnnl_bn", BatchNorm}}, + {"add", {"dnnl_binary_op", Add}}, {"multiply", {"dnnl_binary_op", Multiply}}, }; const auto op_name = GetRef(op_node)->name; diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 19b3f796fd33..d1190df91375 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -44,6 +44,32 @@ typedef struct { void** data; } DnnlPackedArgs; +inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, + memory::data_type dtype) { + using tag = memory::format_tag; + + dnnl::memory::desc data_md; + + switch (shape.size()) { + case 2: + data_md = dnnl::memory::desc({shape, dtype, tag::ab}); + break; + case 3: + data_md = dnnl::memory::desc({shape, dtype, tag::abc}); + break; + case 4: + data_md = dnnl::memory::desc({shape, dtype, tag::abcd}); + break; + case 5: + data_md = dnnl::memory::desc({shape, dtype, tag::abcde}); + break; + default: + LOG(FATAL) << "Unsupported data shape dimension: " << shape.size(); + break; + } + return data_md; +} + // Read from memory, write to handle inline void read_from_dnnl_memory(void* handle, const memory& mem) { size_t bytes = mem.get_desc().get_size(); @@ -175,16 +201,13 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int read_from_dnnl_memory(out, dst_memory); } -extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_) { - using tag = memory::format_tag; +extern "C" void dnnl_relu(float* data, float* out, std::vector shape) { using dt = memory::data_type; engine eng(engine::kind::cpu, 0); stream s(eng); - memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; - - auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + auto data_md = GenDNNLMemDescByShape(shape, dt::f32); auto data_memory = memory(data_md, eng, data); auto dst_memory = memory(data_md, eng); @@ -241,27 +264,39 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo free(weight); } -extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_, int p_H_, - int p_W_) { - using tag = memory::format_tag; +// should comply with src/relay/backend/contrib/dnnl/codegen.cc +#define DNNL_BINARY_ADD 0 +#define DNNL_BINARY_MUL 1 + +extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_type, + std::vector shape) { using dt = memory::data_type; engine eng(engine::kind::cpu, 0); stream s(eng); - memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; - - auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; - auto weight_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); - auto dst_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); + auto data_md = GenDNNLMemDescByShape(shape, dt::f32); auto data_memory = memory(data_md, eng, data); - auto weight_memory = memory(weight_md, eng, weight); - auto dst_memory = memory(dst_md, eng); + auto weight_memory = memory(data_md, eng, weight); + auto dst_memory = memory(data_md, eng); - auto add_desc = binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); + algorithm algo = algorithm::undef; + switch (algo_type) { + case DNNL_BINARY_ADD: + algo = algorithm::binary_add; + break; + case DNNL_BINARY_MUL: + algo = algorithm::binary_mul; + break; + default: + LOG(FATAL) << "Unsupported dnnl algorithm: " << algo_type; + break; + } + + auto add_desc = binary::desc(algo, data_md, data_md, data_md); auto add_prim_desc = binary::primitive_desc(add_desc, eng); - assert(dst_md == add_prim_desc.dst_desc()); + assert(data_md == add_prim_desc.dst_desc()); auto add = binary(add_prim_desc); add.execute( diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index f5f28fccd8e7..522313ae5a64 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -26,6 +26,9 @@ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #include +#include + +#include #include "dnnl.hpp" @@ -54,14 +57,14 @@ extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_); -extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); +extern "C" TVM_DLL void dnnl_relu(float* data, float* out, std::vector shape); extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance, float* out, float* new_mean, float* new_variance, int p_n_, int p_c_, int p_h_, int p_w_, int p_e_); -extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, - int p_h_, int p_w_); +extern "C" TVM_DLL void dnnl_binary_op(float* data, float* weight, float* out, int binary_algo, + std::vector shape); } // namespace contrib } // namespace runtime From 285dbd840990f85f37dc98f84db143bac3da2169 Mon Sep 17 00:00:00 2001 From: Anastasia Stulova <38433336+AnastasiaStulova@users.noreply.github.com> Date: Wed, 29 Sep 2021 01:58:38 +0100 Subject: [PATCH 08/20] [OpenCL] Remove redundant visit statement in CodeGen. (#9144) Fixes regression with some models on which compilation doesn't terminate. --- src/target/source/codegen_opencl.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 7abff36a3ddb..d93a7fde639a 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -59,8 +59,6 @@ class InferTextureAccess : public StmtExprVisitor { var_access_map_[op->args[0].as()] |= kReadAccess; } else if (op->op.same_as(builtin::texture2d_store())) { var_access_map_[op->args[0].as()] |= kWriteAccess; - } else { - StmtExprVisitor::VisitExpr_(op); } StmtExprVisitor::VisitExpr_(op); } From 0467539b35fe4ab6b577dc60c627f9923e584c5f Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Tue, 28 Sep 2021 19:58:58 -0500 Subject: [PATCH 09/20] [UnitTest] Parametrized test_conv2d_int8_intrinsics (#9143) Parametrized it to get more detailed information while debugging failures in https://github.com/apache/tvm/pull/9091, but isn't semantically part of that PR. --- tests/python/relay/test_op_level2.py | 225 +++++++++++++-------------- 1 file changed, 106 insertions(+), 119 deletions(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 44f211dd9f8a..da2877063c45 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1587,156 +1587,143 @@ def test_upsampling3d(): _test_upsampling3d("NDHWC", "trilinear", "align_corners") -@tvm.testing.uses_gpu -def test_conv2d_int8_intrinsics(): - def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): +@pytest.mark.skipif(tvm.target.codegen.llvm_version_major() < 8, reason="Requires LLVM 8") +class TestConv2DInt8Intrinsics: + supported_targets = [ + "llvm -mcpu=nehalem", + "llvm -mcpu=core-avx2", + "llvm -mcpu=skylake-avx512", + "llvm -mcpu=cascadelake", + ] + + unsupported_targets = [ + "llvm -mcpu=x86-64", + ] + + data_layout, kernel_layout = tvm.testing.parameters( + ("NCHW", "OIHW"), + # TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout. + # Re-enable this after adding conv2d_NCHWc_int8 support for NHWC. + # ("NHWC", "HWIO"), + ) + + input_channels, output_channels = tvm.testing.parameters( + # Sweep the input channels to check int8 robustness + # Input channels should be a multiple of 4 internally. + (1, 16), + (4, 16), + (6, 16), + # Sweep the output channels to check int8 robustness + # Output channels should be a multiple of 16 internally. + (8, 4), + (8, 16), + (8, 20), + # Check that both non-divisible oc and ic work + (17, 29), + ) + + @tvm.testing.fixture + def fast_int8_intrinsic(self, target): + if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target: + return "pmaddubs" + elif "cascadelake" in target: + return "vpdpbusd" + else: + assert False, "Target should be Skylake or Cascadelake" + + @tvm.testing.fixture + def assembly( + self, + target, + dtypes, + input_channels, + output_channels, + data_layout, + kernel_layout, + ): input_dtype, weight_dtype, output_dtype = dtypes - n, h, w, ch, cw = 1, 64, 64, 3, 3 + image_size = (64, 64) + kernel_size = (3, 3) + batch_size = 1 + + h, w = image_size + if data_layout == "NCHW": - data_shape = (n, ic, h, w) - x = relay.var("x", relay.TensorType(data_shape, input_dtype)) + data_shape = (batch_size, input_channels, *image_size) elif data_layout == "NHWC": - data_shape = (n, h, w, ic) - x = relay.var("x", relay.TensorType(data_shape, input_dtype)) + data_shape = (batch_size, *image_size, input_channels) else: - raise ValueError("Not supported") + raise ValueError(f"Unsupported data layout: {data_layout}") + x = relay.var("x", relay.TensorType(data_shape, input_dtype)) if kernel_layout == "OIHW": - kernel_shape = (oc, ic, ch, cw) + kernel_shape = (output_channels, input_channels, *kernel_size) elif kernel_layout == "HWIO": - kernel_shape = (ch, cw, ic, oc) + kernel_shape = (*kernel_size, input_channels, output_channels) else: raise ValueError("Not supported") - weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype)) + y = relay.nn.conv2d( x, weight, - kernel_size=(ch, cw), - channels=oc, + kernel_size=kernel_size, + channels=output_channels, padding=(0, 0, 0, 1), dilation=(1, 1), data_layout=data_layout, kernel_layout=kernel_layout, out_dtype=output_dtype, ) + func = relay.Function([x, weight], y) + wdata = np.random.rand(*kernel_shape) * 10 parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))} with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=parameters) - assembly = lib.get_source("asm") - return assembly - - def _has_fast_int8_instructions(asm, target): - if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target: - return "pmaddubs" in asm - elif "cascadelake" in target: - return "vpdpbusd" in asm - else: - assert False, "Target should be Skylake or Cascadelake" - - # TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout. - # Re-enable this after adding conv2d_NCHWc_int8 support for NHWC. - - # compile conv2d for x86 (SSE3/AVX2/AVX512/VNNI capable) and test assembly contains *pmadd* instructions - targets = [ - "llvm -mcpu=nehalem", - "llvm -mcpu=core-avx2", - "llvm -mcpu=skylake-avx512", - "llvm -mcpu=cascadelake", - ] - llvm_version = tvm.target.codegen.llvm_version_major() - for target in targets: - if tvm.testing.device_enabled(target) and llvm_version >= 8: - dtypes = ("uint8", "int8", "int32") - # Sweep the input channels to check int8 robustness - # Input channels should be a multiple of 4 internally. - for ic in [1, 4, 6]: - asm = _compile( - ic=ic, - oc=16, - target=target, - data_layout="NCHW", - kernel_layout="OIHW", - dtypes=dtypes, - ) - assert _has_fast_int8_instructions(asm, target) - - # for ic in [1, 4, 6]: - # asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC", - # kernel_layout='HWIO', - # dtypes=dtypes) - # assert _has_fast_int8_instructions(asm, target) - - # Sweep the output channels to check int8 robustness - # Output channels should be a multiple of 16 internally. - for oc in [4, 16, 20]: - asm = _compile( - ic=8, - oc=oc, - target=target, - data_layout="NCHW", - kernel_layout="OIHW", - dtypes=dtypes, - ) - assert _has_fast_int8_instructions(asm, target) - - # for oc in [4, 16, 20]: - # asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC", - # kernel_layout='HWIO', - # dtypes=dtypes) - # assert _has_fast_int8_instructions(asm, target) - - # Check that both non-divisible oc and ic work - asm = _compile( - ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout="OIHW", dtypes=dtypes - ) - assert _has_fast_int8_instructions(asm, target) - - # asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO', - # dtypes=dtypes) - # assert _has_fast_int8_instructions(asm, target) - - # Check that int8 x int8 goes through legalization so that fast instructions can be picked up. - for target in targets: - if tvm.testing.device_enabled(target) and llvm_version >= 8: - dtypes = ("int8", "int8", "int32") - # Check that both non-divisible oc and ic work - asm = _compile( - ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout="OIHW", dtypes=dtypes - ) - assert _has_fast_int8_instructions(asm, target) - - # asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO', - # dtypes=dtypes) - # assert _has_fast_int8_instructions(asm, target) + return lib.get_source("asm") + + # Ensure that code uses the fast int8 instructions when available. + @tvm.testing.parametrize_targets(*supported_targets) + @pytest.mark.parametrize( + "dtypes", + [ + # compile conv2d for x86 (skylake, cascadelake) and test + # assembly contains *pmadd* instructions + ("uint8", "int8", "int32"), + # Check that int8 x int8 goes through legalization so that + # fast instructions can be picked up. + ("int8", "int8", "int32"), + ], + ) + def test_uses_intrinsic( + self, + fast_int8_intrinsic, + assembly, + ): + assert fast_int8_intrinsic in assembly - # Ensure that code is generated when datatypes are not HW supported. - # dtypes = ('uint8', 'uint8', 'int32') - # asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', - # dtypes=dtypes) - # # Check that intrinisic is not present in the assembly. - # assert not _has_fast_int8_instructions(asm, target) + # For datatypes that don't have HW support, ensure that code is + # generated without the fast int8 intrinsic. + @tvm.testing.parametrize_targets(*supported_targets) + @pytest.mark.parametrize("dtypes", [("uint8", "uint8", "int32")]) + def test_no_intrinsic( + self, + fast_int8_intrinsic, + assembly, + ): + assert fast_int8_intrinsic not in assembly # Check that a vectorized instruction is generated for older Intel # generations, because we default to NCHWc layout. - target = "llvm -mcpu=x86-64" - if tvm.testing.device_enabled(target): - fast_int8_dtypes = ("uint8", "int8", "int32") - asm = _compile( - ic=16, - oc=32, - target=target, - data_layout="NCHW", - kernel_layout="OIHW", - dtypes=fast_int8_dtypes, - ) - # Check that vector int mult and add instructions are generated. - assert "pmulhw" in asm and "paddd" in asm + @tvm.testing.parametrize_targets(*unsupported_targets) + @pytest.mark.parametrize("dtypes", [("uint8", "int8", "int32")]) + def test_uses_vectorized_instruction(self, assembly): + assert "pmulhw" in assembly and "paddd" in assembly @tvm.testing.uses_gpu From e4946f470ca929ba350ebeaeb06b0812f705f186 Mon Sep 17 00:00:00 2001 From: Arun Abraham Date: Wed, 29 Sep 2021 10:57:21 +0530 Subject: [PATCH 10/20] [Frontend][PyTorch] support for quantized conv_transpose2d op (#9133) * [Frontend][PyTorch] support for quantized conv_transpose2d op PyTorch uses the same underlying function to pack and unpack the params for conv2d and conv_transpose2d ops. This change adds support for quantized conv_transpose2d op by reusing the ConvPackedParam and adding the output_padding param to it. This output_padding param will remain unused in case of conv2d. Also added test for above with specific condition for torch v1.7.1 and below. * fix after merging main --- python/tvm/relay/frontend/qnn_torch.py | 100 +++++++++++++++++++++- tests/python/frontend/pytorch/qnn_test.py | 26 +++++- 2 files changed, 121 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index af3c352d15ae..172ab1e41268 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -51,12 +51,25 @@ class ConvPackedParam(QNNParam): together with weights and quantization parameters """ - def __init__(self, weight_np, bias, scale, zero_point, stride, padding, dilation, groups): + def __init__( + self, + weight_np, + bias, + scale, + zero_point, + stride, + padding, + dilation, + groups, + output_padding, + ): super().__init__(weight_np, bias, scale, zero_point) self.stride = stride self.padding = padding self.dilation = dilation self.groups = groups + # Used only for conv_transpose2d + self.output_padding = output_padding def _get_quant_params(qweight): @@ -86,7 +99,18 @@ def make_conv_packed_param(qweight, bias, packed_params): padding = packed_params.padding() dilation = packed_params.dilation() groups = packed_params.groups() - return ConvPackedParam(weight_np, bias, scale, zero_point, stride, padding, dilation, groups) + output_padding = packed_params.output_padding() + return ConvPackedParam( + weight_np, + bias, + scale, + zero_point, + stride, + padding, + dilation, + groups, + output_padding, + ) def get_weight_quant_params(script_module, packed_param_names): @@ -208,7 +232,13 @@ def add_quant_params_to_outputs( params = [qweight, qparam.scale, qparam.zero_point, qbias] if isinstance(quant_params[packed_param_name], ConvPackedParam): - params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups] + params += [ + qparam.stride, + qparam.padding, + qparam.dilation, + qparam.groups, + qparam.output_padding, + ] outputs[node_name] = params @@ -246,6 +276,7 @@ def _get_quant_param_for_input(input_value): "quantized::mul_scalar": (2, 3), "quantized::add_scalar": (2, 3), "quantized::hardswish": (1, 2), + "quantized::conv_transpose2d": qconv_indices, } def dfs(current_node): @@ -416,6 +447,7 @@ def add_input_quant_params_to_op_inputs(graph): "quantized::relu6": 1, "quantized::hardswish": 1, "aten::hardsigmoid": 1, + "quantized::conv_transpose2d": 1, } need_input_quant_param = set(num_quantized_inputs.keys()) @@ -457,7 +489,7 @@ def add_input_quant_params_to_op_inputs(graph): node.addInput(scale) node.addInput(zp) - if "conv2d" in operator or "linear" in operator: + if "conv" in operator or "linear" in operator: # This is required for quantizing the bias input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value") @@ -983,6 +1015,65 @@ def _impl(inputs, _): return _impl +def _quantized_conv_transpose2d(with_relu=False): + def _impl(inputs, _): + # Refer to aten/src/ATen/native/quantized/cpu/qconv.cpp + # Supported in Torch 1.7 or newer + conv_params = inputs[1] + weight = conv_params[0] + weight_scale = conv_params[1] + weight_zero_point = conv_params[2] + bias = conv_params[3] + + strides = conv_params[4] + padding = conv_params[5] + dilation = conv_params[6] + groups = conv_params[7] + output_padding = conv_params[8] + + output_scale = _expr.const(inputs[2]) + output_zero_point = _expr.const(inputs[3]) + + assert len(inputs) == 6, "Input quant params not found in op inputs" + + # These are manually added by add_input_quant_params_to_op_inputs above + # In torch, they are retrieved from QTensor data structure at runtime + input_scale = _expr.const(inputs[4]) + input_zero_point = _expr.const(inputs[5]) + + weight_shape = list(infer_shape(weight)) + + # Swap I and O dims to match shape relay expects for OIHW + weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0] + + kernel_size = (weight_shape[2], weight_shape[3]) + out_channels = weight_shape[0] + + conv_out = relay.qnn.op.conv2d_transpose( + inputs[0], + weight, + input_zero_point, + weight_zero_point, + input_scale, + weight_scale, + kernel_size=kernel_size, + dilation=dilation, + strides=strides, + padding=padding, + groups=groups, + channels=out_channels, + output_padding=output_padding, + out_dtype="int32", + kernel_layout="OIHW", + ) + + return _do_bias_and_requantize( + conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu + ) + + return _impl + + convert_map = { "aten::quantize_per_tensor": _quantize_per_tensor(), "quantized::conv2d_relu": _quantized_conv2d(with_relu=True), @@ -1000,4 +1091,5 @@ def _impl(inputs, _): "quantized::relu6": _relu6(), "quantized::linear_dynamic": _linear_dynamic(), "quantized::hardswish": _hswish(), + "quantized::conv_transpose2d": _quantized_conv_transpose2d(), } diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 65e5692dc4fb..9f145b75a405 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -98,6 +98,20 @@ def fuse_model(self): fuse_modules(self.conv, indices, inplace=True) +class ConvTranspose(nn.Module): + def __init__(self): + super().__init__() + layers = [nn.ConvTranspose2d(3, 32, 3, bias=True)] + self.conv = nn.Sequential(*layers) + self.quant_wrap = QuantWrapper(self.conv) + + def forward(self, x): + return self.quant_wrap(x) + + def fuse_model(self): + pass + + class Linear(nn.Module): def __init__(self, with_relu=False): super().__init__() @@ -276,6 +290,7 @@ def test_quantized_modules(): ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel), ("linear" + postfix, (16, 16), Linear(), per_channel), ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel), + ("conv_transpose", imagenet_ishape, ConvTranspose(), False), ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False), ("hswish", imagenet_ishape, Hswish(add_stub=True), False), ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False), @@ -287,7 +302,15 @@ def test_quantized_modules(): raw_module.eval() inp = torch.rand(ishape) - quantize_model(raw_module, inp, per_channel=per_channel) + # quantized conv_transpose2d is supported only with qnnpack engine before torch v1.8.0. + if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"): + prev_engine = torch.backends.quantized.engine + torch.backends.quantized.engine = "qnnpack" + quantize_model(raw_module, inp, per_channel=per_channel) + torch.backends.quantized.engine = prev_engine + else: + quantize_model(raw_module, inp, per_channel=per_channel) + script_module = torch.jit.trace(raw_module, inp).eval() with torch.no_grad(): @@ -314,6 +337,7 @@ def test_quantized_modules(): conv_bn_relu 0.3700896 0.010921672 0.7489366477964451 linear 0.15987062 0.009231662 0.794921875 linear_relu 0.14180502 0.0053220326 0.8828125 + conv_transpose 0.0033792555 4.4658788e-07 0.9998678439971806 conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019 conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732 linear, per_channel 0.0 0.0 1.0 From a16ccf4967bc6fa1906c6df76371bb8f3cb26bfd Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 28 Sep 2021 22:39:41 -0700 Subject: [PATCH 11/20] [Meta Schedule][M3a] SearchStrategy (#9132) * Add c++ side SearchStrategy. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng * Add python-side code & test. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng * Add docs. * Minor fix. * Add workflow. * Add docs. * Fix docs. * Add notes. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/meta_schedule/builder.h | 4 +- include/tvm/meta_schedule/runner.h | 62 +++++ include/tvm/meta_schedule/search_strategy.h | 246 ++++++++++++++++++ include/tvm/meta_schedule/space_generator.h | 38 ++- include/tvm/support/random_engine.h | 10 + python/tvm/meta_schedule/__init__.py | 2 + python/tvm/meta_schedule/runner/__init__.py | 18 ++ python/tvm/meta_schedule/runner/runner.py | 59 +++++ .../meta_schedule/search_strategy/__init__.py | 20 ++ .../search_strategy/replay_trace.py | 47 ++++ .../search_strategy/search_strategy.py | 166 ++++++++++++ src/meta_schedule/runner/runner.cc | 41 +++ .../search_strategy/replay_trace.cc | 148 +++++++++++ .../search_strategy/search_strategy.cc | 68 +++++ src/meta_schedule/utils.h | 76 ++++++ src/tir/schedule/concrete_schedule.cc | 4 +- src/tir/schedule/primitive.h | 8 + src/tir/schedule/primitive/sampling.cc | 12 + .../test_meta_schedule_search_strategy.py | 98 +++++++ 19 files changed, 1121 insertions(+), 6 deletions(-) create mode 100644 include/tvm/meta_schedule/runner.h create mode 100644 include/tvm/meta_schedule/search_strategy.h create mode 100644 python/tvm/meta_schedule/runner/__init__.py create mode 100644 python/tvm/meta_schedule/runner/runner.py create mode 100644 python/tvm/meta_schedule/search_strategy/__init__.py create mode 100644 python/tvm/meta_schedule/search_strategy/replay_trace.py create mode 100644 python/tvm/meta_schedule/search_strategy/search_strategy.py create mode 100644 src/meta_schedule/runner/runner.cc create mode 100644 src/meta_schedule/search_strategy/replay_trace.cc create mode 100644 src/meta_schedule/search_strategy/search_strategy.cc create mode 100644 tests/python/unittest/test_meta_schedule_search_strategy.py diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 9186c9d039e0..19358552df10 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -25,7 +25,7 @@ namespace tvm { namespace meta_schedule { -/*! \brief The builder's input. */ +/*! \brief The builder's input, containing an IRModule and the target. */ class BuilderInputNode : public runtime::Object { public: /*! \brief The IRModule to be built. */ @@ -57,7 +57,7 @@ class BuilderInput : public runtime::ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); }; -/*! \brief The builder's output. */ +/*! \brief The builder's output, containing the artifact path or error message if any. */ class BuilderResultNode : public runtime::Object { public: /*! \brief The path to the built artifact. */ diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h new file mode 100644 index 000000000000..36d07024559d --- /dev/null +++ b/include/tvm/meta_schedule/runner.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_RUNNER_H_ +#define TVM_META_SCHEDULE_RUNNER_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */ +class RunnerResultNode : public runtime::Object { + public: + /*! \brief The run time in seconds. If not None, error_msg should be None. */ + Optional> run_secs; + /*! \brief The error message, if any. If not None, run_secs should be None. */ + Optional error_msg; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("run_secs", &run_secs); + v->Visit("error_msg", &error_msg); + } + + static constexpr const char* _type_key = "meta_schedule.RunnerResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerResultNode + * \sa RunnerResultNode + */ +class RunnerResult : public runtime::ObjectRef { + public: + /*! + * \brief Constructor for RunnerResult. + * \param run_secs The run time in seconds. + * \param error_msg The error message, if any. + */ + TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_RUNNER_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h new file mode 100644 index 000000000000..941dae4336e1 --- /dev/null +++ b/include/tvm/meta_schedule/search_strategy.h @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ +#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ + +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +// Forward declaration +class TuneContext; + +/*! \brief The schedule (with input shapes) to be measured. */ +class MeasureCandidateNode : public runtime::Object { + public: + /*! \brief The schedule for measurement. */ + tir::Schedule sch; + /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("sch", &sch); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); +}; + +/*! + * \brief Managed reference to MeasureCandidateNode. + * \sa MeasureCandidateNode + */ +class MeasureCandidate : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of MeasureCandidate. + * \param sch The schedule for measurement. + * \param args_info The argument information, e.g., (shape, dtype) for tensors. + */ + TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); +}; + +/*! + * \brief The search strategy for measure candidates generation. + * \note The relationship between SearchStrategy and other classes are as follows: + ┌──────────────────────────────────────────────────────────────┐ + ┌──┴───────────────────────────────────────────────────────────┐ │ +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ +│ ┌─────────────────────┐ │ │ │ +│ │ │ Generate │ │ │ +│ │ Space Generator ├──────────────┐ │ │ │ +│ │ │ │ │ │ │ +│ └─────────────────────┘ ▼ │ │ │ +│ Design Space │ │ │ +│ ┌─────────────────────┐ │ │ │ │ +│ Generate │ │ Pretuning │ │ │ │ +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ +│ │ │ │ │ ├──┘ +│ │ └─────────────────────┘ ├──┘ +└────┼─────────────────────────────────────────────────────────┘ + │ + │ +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ +│ │ ┌───────────┐ │ +│ │ Send to │ │ Send to │ +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ +│ Measure Candidate │ Builder │ │ Runner │ │ +│ │ │ └───────────┘ │ │ +│ │ ┌────────────┴────────┐ │ │ +│ │ │ │ ┌───────────┐ │ │ +│ └────►│ Task Scheduler │ │ │ │ │ +│ │ │ │ Runner │◄─────────┘ │ +│ └─────────────────────┘ │ │ │ +│ ▲ └─────┬─────┘ │ +│ │ │ │ +│ └─── Runner Future ◄────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +*/ +class SearchStrategyNode : public runtime::Object { + public: + /*! \brief Virtual destructor */ + virtual ~SearchStrategyNode() = default; + + /*! + * \brief Initialize the search strategy with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; + + /*! + * \brief Pre-tuning for the search strategy. + * \param design_spaces The design spaces for pre-tuning. + * \note Pre-tuning is supposed to be called before the tuning process and after the + * initialization. Because the search strategy is stateful, we can always call pretuning + * and reset the search strategy. + */ + virtual void PreTuning(const Array& design_spaces) = 0; + + /*! + * \brief Post-tuning for the search strategy. + * \note Post-tuning is supposed to be called after the tuning process and before we reset the + * search strategy with another pre-tuning. Post-tuning can be empty. + */ + virtual void PostTuning() = 0; + + /*! + * \brief Generate measure candidates from design spaces for measurement. + * \return The measure candidates generated, nullptr if finished. + */ + virtual Optional> GenerateMeasureCandidates() = 0; + + /*! + * \brief Update the search strategy with measurement results. + * \param results The measurement results from the runner. + */ + virtual void NotifyRunnerResults(const Array& results) = 0; + + static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); +}; + +/*! \brief The python side customizable class for measure candidate generation */ +class PySearchStrategyNode : public SearchStrategyNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `PreTuning` method. + * \param design_spaces The design spaces for pre-tuning. + */ + using FPreTuning = runtime::TypedPackedFunc&)>; + /*! \brief The function type of `PostTuning` method. */ + using FPostTuning = runtime::TypedPackedFunc; + /*! + * \brief The function type of `GenerateMeasureCandidates` method. + * \return The measure candidates generated, nullptr if finished. + */ + using FGenerateMeasureCandidates = runtime::TypedPackedFunc>()>; + /*! + * \brief The function type of `NotifyRunnerResults` method. + * \param results The measurement results from the runner. + */ + using FNotifyRunnerResults = runtime::TypedPackedFunc&)>; + + /*! \brief The packed function to the `InitializeWithTuneContext` method. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `PreTuning` method. */ + FPreTuning f_pre_tuning; + /*! \brief The packed function to the `PostTuning` method. */ + FPostTuning f_post_tuning; + /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ + FGenerateMeasureCandidates f_generate_measure_candidates; + /*! \brief The packed function to the `NotifyRunnerResults` method. */ + FNotifyRunnerResults f_notify_runner_results; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_pre_tuning` is not visited + // `f_post_tuning` is not visited + // `f_generate_measure_candidates` is not visited + // `f_notify_runner_results` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + this->f_initialize_with_tune_context(context); + } + + void PreTuning(const Array& design_spaces) final { + this->f_pre_tuning(design_spaces); + } + + void PostTuning() final { this->f_post_tuning(); } + + Optional> GenerateMeasureCandidates() final { + return this->f_generate_measure_candidates(); + } + + void NotifyRunnerResults(const Array& results) final { + this->f_notify_runner_results(results); + } + + static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); +}; + +/*! + * \brief Managed reference to SearchStrategyNode. + * \sa SearchStrategyNode + */ +class SearchStrategy : public runtime::ObjectRef { + public: + /*! + * \brief Create a search strategy with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_pre_tuning The packed function of `PreTuning`. + * \param f_post_tuning The packed function of `PostTuning`. + * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`. + * \param f_notify_runner_results The packed function of `NotifyRunnerResults`. + * \return The search strategy created. + */ + TVM_DLL static SearchStrategy PySearchStrategy( + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PySearchStrategyNode::FPreTuning f_pre_tuning, // + PySearchStrategyNode::FPostTuning f_post_tuning, // + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results); + + /*! + * \brief Constructor of replay trace search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for trace replaying. + */ + TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 9528be2a85ad..3dc181e05d8a 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -28,7 +28,42 @@ namespace meta_schedule { // Forward declaration class TuneContext; -/*! \brief The abstract class for design space generation. */ +/*! + * \brief The abstract class for design space generation. + * \note The relationship between SpaceGenerator and other classes are as follows: + ┌──────────────────────────────────────────────────────────────┐ + ┌──┴───────────────────────────────────────────────────────────┐ │ +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ +│ ┌─────────────────────┐ │ │ │ +│ │ │ Generate │ │ │ +│ │ Space Generator ├──────────────┐ │ │ │ +│ │ │ │ │ │ │ +│ └─────────────────────┘ ▼ │ │ │ +│ Design Space │ │ │ +│ ┌─────────────────────┐ │ │ │ │ +│ Generate │ │ Pretuning │ │ │ │ +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ +│ │ │ │ │ ├──┘ +│ │ └─────────────────────┘ ├──┘ +└────┼─────────────────────────────────────────────────────────┘ + │ + │ +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ +│ │ ┌───────────┐ │ +│ │ Send to │ │ Send to │ +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ +│ Measure Candidate │ Builder │ │ Runner │ │ +│ │ │ └───────────┘ │ │ +│ │ ┌────────────┴────────┐ │ │ +│ │ │ │ ┌───────────┐ │ │ +│ └────►│ Task Scheduler │ │ │ │ │ +│ │ │ │ Runner │◄─────────┘ │ +│ └─────────────────────┘ │ │ │ +│ ▲ └─────┬─────┘ │ +│ │ │ │ +│ └─── Runner Future ◄────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +*/ class SpaceGeneratorNode : public Object { public: /*! \brief Default destructor */ @@ -37,6 +72,7 @@ class SpaceGeneratorNode : public Object { /*! * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 6b733d074f6a..fcd2326050ed 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -102,6 +102,16 @@ class LinearCongruentialEngine { *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. } + /*! + * \brief Fork a new seed for another RNG from current random state. + * \return The forked seed. + */ + TRandState ForkSeed() { + // In order for reproducibility, we computer the new seed using RNG's random state and a + // different set of parameters. Note that both 32767 and 1999999973 are prime numbers. + return ((*this)() * 32767) % 1999999973; + } + /*! * \brief Construct a random number generator with a random state pointer. * \param rand_state_ptr The random state pointer given in result_type*. diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index f8b2b026c83b..c22cc205bf35 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -19,5 +19,7 @@ from . import builder from . import database from . import space_generator +from . import search_strategy +from . import runner from .database import TuningRecord from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/runner/__init__.py b/python/tvm/meta_schedule/runner/__init__.py new file mode 100644 index 000000000000..65d2ef04e04c --- /dev/null +++ b/python/tvm/meta_schedule/runner/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""meta_schedule.runner""" +from .runner import RunnerResult diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py new file mode 100644 index 000000000000..b756c6e6b011 --- /dev/null +++ b/python/tvm/meta_schedule/runner/runner.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Runners""" +from typing import List, Optional + +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api + + +@register_object("meta_schedule.RunnerResult") +class RunnerResult(Object): + """The runner's result + + Parameters + ---------- + run_secs : Optional[List[float]] + The run time in seconds. + error_msg : Optional[str] + The error message, if any. + """ + + run_secs: Optional[List[float]] + error_msg: Optional[str] + + def __init__( + self, + run_secs: Optional[List[float]], + error_msg: Optional[str], + ) -> None: + """Constructor + + Parameters + ---------- + run_secs : Optional[List[float]] + The run time in seconds. + error_msg : Optional[str] + The error message, if any. + """ + self.__init_handle_by_constructor__( + _ffi_api.RunnerResult, # type: ignore # pylint: disable=no-member + run_secs, + error_msg, + ) diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py new file mode 100644 index 000000000000..40f21da0b2d1 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Search Strategy""" + +from .search_strategy import SearchStrategy, PySearchStrategy +from .replay_trace import ReplayTrace diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py new file mode 100644 index 000000000000..3afdff6de77e --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" + +from tvm._ffi import register_object +from .search_strategy import SearchStrategy +from .. import _ffi_api + + +@register_object("meta_schedule.ReplayTrace") +class ReplayTrace(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__(self, num_trials_per_iter: int, num_trials_total: int): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.ReplayTrace, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py new file mode 100644 index 000000000000..72713155c41d --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Search Strategy""" + +from typing import List, Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..arg_info import ArgInfo +from ..runner import RunnerResult + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.MeasureCandidate") +class MeasureCandidate(Object): + """Measure candidate class. + + Parameters + ---------- + sch : Schedule + The schedule to be measured. + args_info : List[ArgInfo] + The argument information. + """ + + sch: Schedule + args_info: List[ArgInfo] + + def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + """Constructor. + + Parameters + ---------- + sch : Schedule + The schedule to be measured. + args_info : List[ArgInfo] + The argument information. + """ + self.__init_handle_by_constructor__( + _ffi_api.MeasureCandidate, # pylint: disable=no-member + sch, + args_info, + ) + + +@register_object("meta_schedule.SearchStrategy") +class SearchStrategy(Object): + """ + Search strategy is the class that generates the measure candidates. It has to be pre-tuned + before usage and post-tuned after usage. + """ + + def initialize_with_tune_context( + self, + tune_context: "TuneContext", + ) -> None: + """Initialize the search strategy with tuning context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initialization. + """ + _ffi_api.SearchStrategyInitializeWithTuneContext( # pylint: disable=no-member + self, tune_context + ) + + def pre_tuning(self, design_spaces: List[Schedule]) -> None: + """Pre-tuning for the search strategy. + + Parameters + ---------- + design_spaces : List[Schedule] + The design spaces for pre-tuning. + """ + _ffi_api.SearchStrategyPreTuning(self, design_spaces) # pylint: disable=no-member + + def post_tuning(self) -> None: + """Post-tuning for the search strategy.""" + _ffi_api.SearchStrategyPostTuning(self) # pylint: disable=no-member + + def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: + """Generate measure candidates from design spaces for measurement. + + Returns + ------- + measure_candidates : Optional[List[IRModule]] + The measure candidates generated, None if finished. + """ + return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # pylint: disable=no-member + + def notify_runner_results(self, results: List[RunnerResult]) -> None: + """Update the search strategy with profiling results. + + Parameters + ---------- + results : List[RunnerResult] + The profiling results from the runner. + """ + _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # pylint: disable=no-member + + +@register_object("meta_schedule.PySearchStrategy") +class PySearchStrategy(SearchStrategy): + """An abstract search strategy with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_initialize_with_tune_context(context: "TuneContext") -> None: + self.initialize_with_tune_context(context) + + def f_pre_tuning(design_spaces: List[Schedule]) -> None: + self.pre_tuning(design_spaces) + + def f_post_tuning() -> None: + self.post_tuning() + + def f_generate_measure_candidates() -> List[MeasureCandidate]: + return self.generate_measure_candidates() + + def f_notify_runner_results(results: List["RunnerResult"]) -> None: + self.notify_runner_results(results) + + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyPySearchStrategy, # pylint: disable=no-member + f_initialize_with_tune_context, + f_pre_tuning, + f_post_tuning, + f_generate_measure_candidates, + f_notify_runner_results, + ) + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + raise NotImplementedError + + def pre_tuning(self, design_spaces: List[Schedule]) -> None: + raise NotImplementedError + + def post_tuning(self) -> None: + raise NotImplementedError + + def generate_measure_candidates(self) -> List[MeasureCandidate]: + raise NotImplementedError + + def notify_runner_results(self, results: List["RunnerResult"]) -> None: + raise NotImplementedError diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc new file mode 100644 index 000000000000..8f509bdd7b84 --- /dev/null +++ b/src/meta_schedule/runner/runner.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +RunnerResult::RunnerResult(Optional> run_secs, Optional error_msg) { + ObjectPtr n = make_object(); + n->run_secs = run_secs; + n->error_msg = error_msg; + this->data_ = n; +} + +TVM_REGISTER_NODE_TYPE(RunnerResultNode); + +TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult") + .set_body_typed([](Array run_secs, Optional error_msg) -> RunnerResult { + return RunnerResult(run_secs, error_msg); + }); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc new file mode 100644 index 000000000000..1c83aee8c0fd --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using trace and random decisions. */ +class ReplayTraceNode : public SearchStrategyNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayTraceNode* self; + /*! \brief The design spaces. */ + Array design_spaces; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayTraceNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The number of threads to use. -1 means using logical cpu number. */ + int num_threads_ = -1; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); + this->num_threads_ = tune_context->num_threads; + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this, design_spaces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + ICHECK_LT(st, ed); + std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); + Array per_task_result(ed - st, MeasureCandidate{nullptr}); + auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id, + int task_id) -> void { + TRandState& rand_state = per_thread_rand_state[thread_id]; + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace trace = design_spaces[design_space_index]->trace().value(); + tir::Trace new_trace = tir::Trace(trace->insts, {}); + tir::Schedule sch = tir::Schedule::Traced( // + self->mod_, // + /*rand_state=*/ForkSeed(&rand_state), // + /*debug_mode=*/0, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_)); + }; + support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); + return per_task_result; +} + +inline void ReplayTraceNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayTraceNode); +TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc new file mode 100644 index 000000000000..fefe8dfce76e --- /dev/null +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array args_info) { + ObjectPtr n = make_object(); + n->sch = sch; + n->args_info = args_info; + data_ = std::move(n); +} + +SearchStrategy SearchStrategy::PySearchStrategy( + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PySearchStrategyNode::FPreTuning f_pre_tuning, // + PySearchStrategyNode::FPostTuning f_post_tuning, // + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = f_initialize_with_tune_context; + n->f_pre_tuning = f_pre_tuning; + n->f_post_tuning = f_post_tuning; + n->f_generate_measure_candidates = f_generate_measure_candidates; + n->f_notify_runner_results = f_notify_runner_results; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(MeasureCandidateNode); +TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode); +TVM_REGISTER_NODE_TYPE(PySearchStrategyNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") + .set_body_typed([](tir::Schedule sch, Array args_info) -> MeasureCandidate { + return MeasureCandidate(sch, args_info); + }); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") + .set_body_typed(SearchStrategy::PySearchStrategy); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") + .set_body_method(&SearchStrategyNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") + .set_body_method(&SearchStrategyNode::PreTuning); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") + .set_body_method(&SearchStrategyNode::PostTuning); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") + .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") + .set_body_method(&SearchStrategyNode::NotifyRunnerResults); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 4c9e1e2c10a1..30294b8f91e1 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -23,16 +23,22 @@ #include #include #include +#include +#include #include #include #include #include +#include #include #include +#include +#include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" +#include "../tir/schedule/primitive.h" namespace tvm { namespace meta_schedule { @@ -131,6 +137,76 @@ inline String JSONObj2Str(const ObjectRef& json_obj) { */ inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } +/*! + * \brief Find the entry function of the given IRModule, i.e, functions marked by + * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. + * \param mod The IRModule to find the entry function. + * \return The entry function. + */ +inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { + // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` + int num_prim_func = 0; + const tir::PrimFuncNode* main_func = nullptr; + const tir::PrimFuncNode* last_func = nullptr; + for (const auto& kv : mod->functions) { + GlobalVar gv = kv.first; + BaseFunc base_func = kv.second; + if (const auto* func = base_func.as()) { + last_func = func; + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + return GetRef(func); + } + if (gv->name_hint == "main") { + main_func = func; + } + ++num_prim_func; + } + } + // Priority 2: PrimFunc whose name is `main` + if (main_func != nullptr) { + return GetRef(main_func); + } + // Priority 3: The only PrimFunc in the IRModule + if (num_prim_func == 0) { + LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " + << tir::AsTVMScript(mod); + } + if (num_prim_func > 1) { + LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " + "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" + << tir::AsTVMScript(mod); + } + return GetRef(last_func); +} + +/*! + * \brief Fork a random state into another, i.e. PRNG splitting. + * The given random state is also mutated. + * \param rand_state The random state to be forked + * \return The forked random state + */ +inline support::LinearCongruentialEngine::TRandState ForkSeed( + support::LinearCongruentialEngine::TRandState* rand_state) { + return support::LinearCongruentialEngine(rand_state).ForkSeed(); +} + +/*! + * \brief Fork a random state into another ones, i.e. PRNG splitting. + * The given random state is also mutated. + * \param rand_state The random state to be forked + * \param n The number of forks + * \return The forked random states + */ +inline std::vector ForkSeed( + support::LinearCongruentialEngine::TRandState* rand_state, int n) { + std::vector results; + results.reserve(n); + for (int i = 0; i < n; ++i) { + results.push_back(support::LinearCongruentialEngine(rand_state).ForkSeed()); + } + return results; +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 07af73ebabb6..93eba520f9d1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -220,9 +220,7 @@ void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState se } support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { - // In order for reproducibility, we computer the new seed using RNG's random state and a different - // set of parameters. Note that both 32767 and 1999999973 are prime numbers. - return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973; + return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 8ad6bdf7d37f..8d8acd2693f4 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -26,6 +26,14 @@ namespace tvm { namespace tir { /******** Schedule: Sampling ********/ +/*! + * \brief Sample a random integer from a given range. + * \param min_inclusive The minimum value of the range, inclusive. + * \param max_exclusive The maximum value of the range, exclusive. + * \return The random integer sampled in the given range. + */ +TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, + int max_exclusive); /*! * \brief Sample once category from candidates according to the probability weights. * \param self The schedule to update diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 8843ac613179..6ac6226118cd 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -24,6 +24,18 @@ namespace tvm { namespace tir { +int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, + int max_exclusive) { + CHECK(min_inclusive < max_exclusive) + << "ValueError: max_exclusive must be greater than min_inclusive."; + if (min_inclusive + 1 == max_exclusive) { + return min_inclusive; + } + support::LinearCongruentialEngine rand_(rand_state); + std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); + return dist(rand_); +} + int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py new file mode 100644 index 000000000000..6e90bddb84b4 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule SearchStrategy """ +# pylint: disable=missing-function-docstring +from typing import List + +import sys + +import pytest + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace + +from tvm.script import ty +from tvm.tir.schedule import Schedule, Trace + + +MATMUL_M = 32 + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking +# fmt: off + +@tvm.script.tir +class Matmul: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (32, 32), "float32") + B = tir.match_buffer(b, (32, 32), "float32") + C = tir.match_buffer(c, (32, 32), "float32") + with tir.block([32, 32, tir.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + return str(trace_1) == str(trace_2) + + +def _schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def test_meta_schedule_replay_trace(): + num_trials_per_iter = 7 + num_trials_total = 20 + + (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul()) + replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + tune_context = TuneContext(mod=Matmul()) + replay.initialize_with_tune_context(tune_context) + + num_trials_each_round: List[int] = [] + replay.pre_tuning([example_sch]) + while True: + candidates = replay.generate_measure_candidates() + if candidates is None: + break + num_trials_each_round.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + assert _is_trace_equal(candidate.sch, example_sch) + runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None)) + replay.notify_runner_results(runner_results) + replay.post_tuning() + assert num_trials_each_round == [7, 7, 6] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 88313d728354a6e03a4174a8a2b4b9eabbf0a200 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Tue, 28 Sep 2021 23:29:30 -0700 Subject: [PATCH 12/20] fix things (#9146) Co-authored-by: Andrew Zhao Luo --- python/tvm/relay/frontend/onnx.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4444b15dfb12..ba2c6b4b54e7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1468,18 +1468,26 @@ class Unsqueeze(OnnxOpConverter): """Operator converter for Unsqueeze.""" @classmethod - def _impl_v1(cls, inputs, attr, params): - axes = sorted(attr["axes"]) + def run_calculation(cls, tensor, axes): + axes = sorted(axes) for axis in axes: - inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1) - return inputs[0] + tensor = _op.expand_dims(tensor, axis=axis, num_newaxis=1) + return tensor @classmethod - def _impl_v12(cls, inputs, attr, params): + def _impl_v1(cls, inputs, attr, params): + return cls.run_calculation(inputs[0], attr["axes"]) + + @classmethod + def _impl_v13(cls, inputs, attr, params): + if isinstance(inputs[1], _expr.Constant): + constant_axes = list(inputs[1].data.numpy()) + constant_axes = list(map(int, constant_axes)) + return cls.run_calculation(inputs[0], constant_axes) + rank_input = len(infer_type(inputs[0]).checked_type.shape) num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0]) axes = relay.split(inputs[1], num_new_axis).astuple() - result = inputs[0] # TODO (AndrewZhaoLuo): investigate performance issues with consecutive From 198a8ab4f124b5147fe275754162f042e594743f Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 29 Sep 2021 16:49:04 +0800 Subject: [PATCH 13/20] [TIR] add loop partition hint pragma (#9121) * add loop partition hint pragma * fix unintialized var * fix to remove hint at last * use tir compare for loop partition testcase --- include/tvm/tir/stmt.h | 6 + src/tir/transforms/loop_partition.cc | 106 +++++++++++++----- .../test_tir_transform_loop_partition.py | 31 ++++- 3 files changed, 116 insertions(+), 27 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0da8e55be023..2ae2877b2f92 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1339,6 +1339,12 @@ constexpr const char* hand_threaded = "hand_threaded"; * if (mask & 2) the write region should be detected. */ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access"; + +/*! + * \brief Mark that the loop should be partitioned. + */ +constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 97f5b6f90a70..c4b83e05706d 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -98,7 +98,13 @@ class CandidateSelector final : public StmtExprVisitor { void VisitStmt_(const ForNode* op) final { // partition const loop when sets partition_const_loop_ if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) { + // always treat var with hint to be partitioned const VarNode* var = op->loop_var.get(); + if (partition_hint_vars.count(var)) { + candidates.insert(GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + return; + } record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var) && !no_split_) { @@ -117,6 +123,12 @@ class CandidateSelector final : public StmtExprVisitor { Var var = iv->var; runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) { + // always treat var with hint to be partitioned + if (partition_hint_vars.count(var.get())) { + candidates.insert(GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + return; + } record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { @@ -125,6 +137,15 @@ class CandidateSelector final : public StmtExprVisitor { record_.erase(var.get()); return; } + } else if (op->attr_key == attr::pragma_loop_partition_hint) { + const VarNode* var = nullptr; + if (op->node->IsInstance()) { + var = op->node.as(); + } else if (op->node->IsInstance()) { + var = op->node.as()->var.get(); + } + ICHECK(var); + partition_hint_vars.insert(var); } StmtExprVisitor::VisitStmt_(op); } @@ -162,6 +183,7 @@ class CandidateSelector final : public StmtExprVisitor { } std::unordered_set candidates; + std::unordered_set partition_hint_vars; private: bool in_likely_{false}; @@ -170,15 +192,28 @@ class CandidateSelector final : public StmtExprVisitor { std::unordered_map record_; }; +// Finder try best to find partitions for hinted vars +#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \ + void VisitExpr_(const OpNodeT* op) final { \ + if (has_partition_hint_) { \ + DeduceCondition(GetRef(op)); \ + return; \ + } \ + StmtExprVisitor::VisitExpr_(op); \ + } + // Populate partitions data structure, i.e., for a specific variable, -// find an interval in which each condition -// (currently, "likely" conditions) has fixed true or false value +// find an interval in which each condition has fixed true or false value class PartitionFinder : public StmtExprVisitor { public: explicit PartitionFinder(Var current_var, const std::unordered_map& hint_map, - const std::unordered_map& relax_map) - : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { + const std::unordered_map& relax_map, + bool has_partition_hint) + : current_var_(current_var), + has_partition_hint_(has_partition_hint), + hint_map_(hint_map), + relax_map_(relax_map) { for (const auto& kv : hint_map) { out_vars_.insert(kv.first); } @@ -218,33 +253,43 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { - PrimExpr cond = op->args[0]; - if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { - // For cond, find out the interval, if exists, in which we can prove that cond is - // true. Also find the interval, if exists, in which we can prove that cond is - // false. - IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); - if (!interval.IsNothing()) { - // cond is true within interval - partitions[{cond, true}] = interval; - } - PrimExpr inverse_cond = InverseCond(cond); - if (inverse_cond.defined()) { - IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); - if (!interval.IsNothing()) { - // cond is false within interval - partitions[{cond, false}] = interval; - } - } - } + DeduceCondition(op->args[0]); } else { StmtExprVisitor::VisitExpr_(op); } } + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GENode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GTNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LENode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LTNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(EQNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(NENode); + Partition partitions; private: + void DeduceCondition(const PrimExpr& cond) { + // For cond, find out the interval, if exists, in which we can prove that cond is + // true. Also find the interval, if exists, in which we can prove that cond is + // false. + if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { + IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); + if (!interval.IsNothing()) { + // cond is true within interval + partitions[{cond, true}] = interval; + } + PrimExpr inverse_cond = InverseCond(cond); + if (inverse_cond.defined()) { + IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); + if (!interval.IsNothing()) { + // cond is false within interval + partitions[{cond, false}] = interval; + } + } + } + } + PrimExpr InverseCond(const PrimExpr& cond) { PrimExpr inverse_cond; if (const LTNode* op = cond.as()) { @@ -270,6 +315,7 @@ class PartitionFinder : public StmtExprVisitor { } Var current_var_; + bool has_partition_hint_; std::unordered_set out_vars_; std::unordered_map hint_map_; std::unordered_map relax_map_; @@ -472,7 +518,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim // include hint of var. hint_map_.insert({var.get(), IntSet::Interval(min, max)}); - PartitionFinder finder(var, hint_map_, relax_map_); + bool has_partition_hint_ = selector.partition_hint_vars.count(var.get()); + PartitionFinder finder(var, hint_map_, relax_map_, has_partition_hint_); finder(body); hint_map_.erase(var.get()); @@ -601,7 +648,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b } } -class RemoveLikelyTags : public StmtExprMutator { +class RemoveLikelyTagsAndHints : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { @@ -611,12 +658,19 @@ class RemoveLikelyTags : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::pragma_loop_partition_hint) { + return VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } }; Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) { stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one) .VisitAndMutate(std::move(stmt)); - stmt = RemoveLikelyTags()(std::move(stmt)); + stmt = RemoveLikelyTagsAndHints()(std::move(stmt)); return stmt; } diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index c632f744bb81..a219b8d96457 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -17,6 +17,8 @@ import tvm import tvm.testing from tvm import te +from tvm import tir +from tvm.script import ty import numpy @@ -434,7 +436,6 @@ def test_conv_tiling(): oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) @@ -538,6 +539,33 @@ def test_simple_rfactor(): assert not tvm.ir.structural_equal(stmt1.body, stmt2.body) +@tvm.script.tir +def partitioned_concat(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [16], dtype="float32") + B = tir.match_buffer(b, [16], dtype="float32") + C = tir.match_buffer(c, [32], dtype="float32") + for i in tir.serial(0, 16): + tir.store(C.data, i, tir.load("float32", A.data, i), True) + for i in tir.serial(0, 16): + tir.store(C.data, i + 16, tir.load("float32", B.data, i + 16), True) + + +def test_explicit_partition_hint(): + A = te.placeholder((16,), name="A") + B = te.placeholder((16,), name="B") + C = te.compute((32,), lambda i: te.if_then_else(i < 16, A[i], B[i]), name="C") + s = te.create_schedule(C.op) + s.normalize() + s[C].pragma(s[C].op.axis[0], "loop_partition_hint") + mod = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None) + with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.tir.transform.Simplify()(mod) + assert tvm.ir.structural_equal(mod["main"], partitioned_concat) + + if __name__ == "__main__": test_basic() test_const_loop() @@ -559,3 +587,4 @@ def test_simple_rfactor(): test_double_splitting_with_indivisible_factors() test_multilevel_splitting_with_indivisble_factors() test_simple_rfactor() + test_explicit_partition_hint() From 0d27ba0cc974c46d746ef5235e89b4edb70cf08c Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Wed, 29 Sep 2021 11:07:49 +0100 Subject: [PATCH 14/20] Fix Google Mock differences between Ubuntu 18.04 and 16.04 (#9141) I thought I got all of these, but turns out I didn't, there's another weirdism in how Google Mock and Google Test are packaged on Ubuntu where the `cmake` command fails due to directories outside of the build root. Double checked logs on Ubuntu 18.04 and Ubuntu 16.04 for this after enabling verbose copying: ```shell $ ./docker/build.sh ci_cpu --net=host ... 'googlemock/libgmock.a' -> '/usr/lib/libgmock.a' 'googlemock/libgmock_main.a' -> '/usr/lib/libgmock_main.a' 'googlemock/gtest/libgtest.a' -> '/usr/lib/libgtest.a' 'googlemock/gtest/libgtest_main.a' -> '/usr/lib/libgtest_main.a' $ ./docker/build.sh ci_i386 --net=host ... 'libgtest.a' -> '/usr/lib/libgtest.a' 'libgtest_main.a' -> '/usr/lib/libgtest_main.a' ... 'libgmock.a' -> '/usr/lib/libgmock.a' 'libgmock_main.a' -> '/usr/lib/libgmock_main.a' ``` --- docker/install/ubuntu_install_core.sh | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/docker/install/ubuntu_install_core.sh b/docker/install/ubuntu_install_core.sh index fb167b92f5c5..f3e97cbf28b0 100755 --- a/docker/install/ubuntu_install_core.sh +++ b/docker/install/ubuntu_install_core.sh @@ -26,6 +26,13 @@ apt-get update && apt-get install -y --no-install-recommends \ libcurl4-openssl-dev libssl-dev libopenblas-dev g++ sudo \ apt-transport-https graphviz pkg-config curl - -cd /usr/src/gtest && cmake CMakeLists.txt && make && cp *.a /usr/lib -cd /usr/src/gmock && cmake CMakeLists.txt && make && cp *.a /usr/lib +if [[ -d /usr/src/googletest ]]; then + # Single package source (Ubuntu 18.04) + # googletest is installed via libgtest-dev + cd /usr/src/googletest && cmake CMakeLists.txt && make && cp -v {googlemock,googlemock/gtest}/*.a /usr/lib +else + # Split source package (Ubuntu 16.04) + # libgtest-dev and google-mock + cd /usr/src/gtest && cmake CMakeLists.txt && make && cp -v *.a /usr/lib + cd /usr/src/gmock && cmake CMakeLists.txt && make && cp -v *.a /usr/lib +fi From f76e141c868b08e17c477e1afe98437117e9e917 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 29 Sep 2021 09:35:56 -0700 Subject: [PATCH 15/20] [Meta Schedule][M3b] Runner (#9111) This PR is part of the meta schedule project (#8473) that adds the asynchronous program runner interface, as well as a reference implementation of RPCRunner. LocalRunner will be implemented with PopenPool executor in a follow-up PR. Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng Address comments Co-authored-by: Cody Yu fix lint --- include/tvm/meta_schedule/runner.h | 169 +++++- python/tvm/meta_schedule/__init__.py | 5 +- .../meta_schedule/builder/local_builder.py | 17 +- python/tvm/meta_schedule/runner/__init__.py | 9 +- python/tvm/meta_schedule/runner/config.py | 190 ++++++ python/tvm/meta_schedule/runner/rpc_runner.py | 567 +++++++++++++++++ python/tvm/meta_schedule/runner/runner.py | 111 ++++ python/tvm/meta_schedule/testing.py | 74 +++ python/tvm/meta_schedule/tune_context.py | 4 +- python/tvm/meta_schedule/utils.py | 37 +- src/meta_schedule/runner/runner.cc | 45 +- .../unittest/test_meta_schedule_runner.py | 571 ++++++++++++++++++ 12 files changed, 1776 insertions(+), 23 deletions(-) create mode 100644 python/tvm/meta_schedule/runner/config.py create mode 100644 python/tvm/meta_schedule/runner/rpc_runner.py create mode 100644 python/tvm/meta_schedule/testing.py create mode 100644 tests/python/unittest/test_meta_schedule_runner.py diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 36d07024559d..a45a4898d64a 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -20,16 +20,53 @@ #define TVM_META_SCHEDULE_RUNNER_H_ #include +#include namespace tvm { namespace meta_schedule { -/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */ +/*! \brief The runner's input. */ +class RunnerInputNode : public runtime::Object { + public: + /*! \brief The path to the built artifact. */ + String artifact_path; + /*! \brief The type of device. */ + String device_type; + /*! \brief The argument information. */ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("artifact_path", &artifact_path); + v->Visit("device_type", &device_type); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.RunnerInput"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerInputNode + * \sa RunnerInputNode + */ +class RunnerInput : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of RunnerInput + * \param artifact_path The path to the built artifact. + * \param device_type The type of device. + * \param args_info The argument information. + */ + TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array args_info); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode); +}; + +/*! \brief The runner's output. */ class RunnerResultNode : public runtime::Object { public: - /*! \brief The run time in seconds. If not None, error_msg should be None. */ + /*! \brief The run time in seconds.*/ Optional> run_secs; - /*! \brief The error message, if any. If not None, run_secs should be None. */ + /*! \brief The error message, if any. */ Optional error_msg; void VisitAttrs(tvm::AttrVisitor* v) { @@ -48,14 +85,134 @@ class RunnerResultNode : public runtime::Object { class RunnerResult : public runtime::ObjectRef { public: /*! - * \brief Constructor for RunnerResult. - * \param run_secs The run time in seconds. - * \param error_msg The error message, if any. + * \brief Constructor + * \brief The run time in seconds. + * \brief The error message, if any. */ TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); }; +/*! + * \brief A class to asynchronously fetch runner's output. + * \note The API design is consistent with python's concurrent.futures.Future: + * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future + */ +class RunnerFutureNode : public runtime::Object { + public: + /*! + * \brief The function type to check whether the runner has finished. + * \return Whether the runner's output is ready. + */ + using FDone = runtime::TypedPackedFunc; + /*! + * \brief The function type to fetch runner output if it is ready. + * \return The runner's output. + */ + using FResult = runtime::TypedPackedFunc; + + /*! \brief The packed function to check whether the runner has finished. */ + FDone f_done; + /*! \brief The packed function to fetch runner output if it is ready. */ + FResult f_result; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_done` is not visited + // `f_result` is not visited + } + + /*! + * \brief Check whether the runner has finished. + * \return A boolean indicating whether the runner has finished. + */ + bool Done() const { return f_done(); } + /*! + * \brief Fetch the runner's output if it is ready. + * \return The runner's output. + */ + RunnerResult Result() const { return f_result(); } + + static constexpr const char* _type_key = "meta_schedule.RunnerFuture"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerFutureNode + * \sa RunnerFutureNode + */ +class RunnerFuture : public runtime::ObjectRef { + public: + using FDone = RunnerFutureNode::FDone; + using FResult = RunnerFutureNode::FResult; + + /*! + * \brief Constructor of RunnerFuture + * \param f_done The packed function to check whether the runner has finished. + * \param f_result The packed function to fetch runner output if it is ready. + */ + TVM_DLL explicit RunnerFuture(FDone f_done, FResult f_result); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerFuture, runtime::ObjectRef, + RunnerFutureNode); +}; + +/*! \brief The abstract runner interface. */ +class RunnerNode : public runtime::Object { + public: + /*! + * \brief The function type to run the built artifacts and get runner futures. + * \param input The runner's inputs. + * \return The runner futures. + * \sa RunnerFuture + */ + using FRun = runtime::TypedPackedFunc(Array)>; + + /*! \brief Default destructor */ + virtual ~RunnerNode() = default; + + /*! + * \brief Run the built artifact and get runner futures. + * \param runner_inputs The runner's inputs. + * \return The runner futures. + */ + virtual Array Run(Array runner_inputs) = 0; + + static constexpr const char* _type_key = "meta_schedule.Runner"; + TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerNode + * \sa RunnerNode + */ +class Runner : public runtime::ObjectRef { + public: + using FRun = RunnerNode::FRun; + + /*! + * \brief Create a runner with customized build method on the python-side. + * \param f_run The packed function to run the built artifacts and get runner futures. + * \return The runner created. + */ + TVM_DLL static Runner PyRunner(FRun f_run); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Runner, runtime::ObjectRef, RunnerNode); +}; + +/*! \brief An abstract runner with customized build method on the python-side. */ +class PyRunnerNode : public RunnerNode { + public: + /*! \brief The packed function to run the built artifacts and get runner futures. */ + FRun f_run; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_run` is not visited + } + + Array Run(Array runner_inputs) final { return f_run(runner_inputs); } + + static constexpr const char* _type_key = "meta_schedule.PyRunner"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index c22cc205bf35..2e280ef20ac3 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -16,10 +16,9 @@ # under the License. """Package `tvm.meta_schedule`. The meta schedule infrastructure.""" from . import arg_info -from . import builder from . import database +from . import builder +from . import runner from . import space_generator from . import search_strategy -from . import runner -from .database import TuningRecord from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index cefe5ec50cad..99dfaea56090 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -48,11 +48,20 @@ class LocalBuilder(PyBuilder): Attributes ---------- T_BUILD : typing._GenericAlias - The signature of the build function `f_build`, which is - `Callable[[IRModule, Target], Module]` + The signature of the function `f_build`, which is + + .. code-block:: python + + def default_build(mod: IRModule, target: Target) -> Module: + ... + T_EXPORT : typing._GenericAlias - The signature of the build function `f_export`, which is - `Callable[[Module], str]` + The signature of the function `f_export`, which is + + .. code-block:: python + + def default_export(mod: Module) -> str: + ... Note ---- diff --git a/python/tvm/meta_schedule/runner/__init__.py b/python/tvm/meta_schedule/runner/__init__.py index 65d2ef04e04c..47f4557e1d3a 100644 --- a/python/tvm/meta_schedule/runner/__init__.py +++ b/python/tvm/meta_schedule/runner/__init__.py @@ -14,5 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""meta_schedule.runner""" -from .runner import RunnerResult +""" +The tvm.meta_schedule.runner package. +Meta Schedule runners that runs an artifact either locally or through the RPC interface +""" +from .config import EvaluatorConfig, RPCConfig +from .rpc_runner import RPCRunner +from .runner import PyRunner, Runner, RunnerFuture, RunnerInput, RunnerResult diff --git a/python/tvm/meta_schedule/runner/config.py b/python/tvm/meta_schedule/runner/config.py new file mode 100644 index 000000000000..712766de99c1 --- /dev/null +++ b/python/tvm/meta_schedule/runner/config.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Configurations for measurements in the runner""" +import os +from threading import Thread +from typing import NamedTuple, Optional, Union + +from tvm import rpc + + +class EvaluatorConfig(NamedTuple): + """Config Details of Evaluator + + Parameters + ---------- + number: int + The number of runs. + repeat: int + The number of times to repeat in each run. + min_repeat_ms: int + Minimum repeat time in ms. if the execution latency is too short, + increase the number of runs to the given time (in ms) to reduce the measurement error. + enable_cpu_cache_flush: bool + Whether to flush the cache on CPU. + + Note + ---- + The total number of actual executions is 1+number*repeat because we would warm up 1 time before + actual run. The number of runs would be increased if run time is below min_repeat_ms. + """ + + number: int = 3 + repeat: int = 1 + min_repeat_ms: int = 40 + enable_cpu_cache_flush: bool = False + + @staticmethod + def _normalized(config: Optional["EvaluatorConfig"]) -> "EvaluatorConfig": + if config is None: + return EvaluatorConfig() + config = EvaluatorConfig( + number=config.number, + repeat=config.repeat, + min_repeat_ms=config.min_repeat_ms, + enable_cpu_cache_flush=config.enable_cpu_cache_flush, + ) + return config + + +class RPCConfig(NamedTuple): + """RPC configuration + + Parameters + ---------- + tracker_host: str + Host of the RPC Tracker + tracker_port: int + Port of the RPC Tracker + tracker_key: str + Key of the Tracker + session_timeout_sec: float + Timeout of the RPC session + session_priority: int + Priority of the RPC session + """ + + tracker_host: Optional[str] = None + tracker_port: Union[None, int, str] = None + tracker_key: Optional[str] = None + session_priority: int = 1 + session_timeout_sec: int = 10 + + def _sanity_check(self) -> None: + err_str = ( + "RPCConfig.{0} is not provided. Please provide it explicitly," + "or set environment variable {1}" + ) + if self.tracker_host is None: + raise ValueError(err_str.format("tracker_host", "TVM_TRACKER_HOST")) + if self.tracker_port is None: + raise ValueError(err_str.format("tracker_port", "TVM_TRACKER_PORT")) + if self.tracker_key is None: + raise ValueError(err_str.format("tracker_key", "TVM_TRACKER_KEY")) + + @staticmethod + def _normalized(config: Optional["RPCConfig"]) -> "RPCConfig": + if config is None: + config = RPCConfig() + config = RPCConfig( + tracker_host=config.tracker_host or os.environ.get("TVM_TRACKER_HOST", None), + tracker_port=config.tracker_port or os.environ.get("TVM_TRACKER_PORT", None), + tracker_key=config.tracker_key or os.environ.get("TVM_TRACKER_KEY", None), + session_priority=config.session_priority, + session_timeout_sec=config.session_timeout_sec, + ) + config._sanity_check() # pylint: disable=protected-access + return config + + def connect_tracker(self) -> rpc.TrackerSession: + """Connect to the tracker + + Returns + ------- + tracker : TrackerSession + The connected tracker session + """ + tracker: Optional[rpc.TrackerSession] = None + + def _connect(): + nonlocal tracker + tracker = rpc.connect_tracker(self.tracker_host, self.tracker_port) + + t = Thread(target=_connect) + t.start() + t.join(self.session_timeout_sec) + if t.is_alive() or tracker is None: + raise ValueError( + "Unable to connect to the tracker using the following configuration:\n" + f" tracker host: {self.tracker_host}\n" + f" tracker port: {self.tracker_port}\n" + f" timeout (sec): {self.session_timeout_sec}\n" + "Please check the tracker status via the following command:\n" + " python3 -m tvm.exec.query_rpc_tracker " + f"--host {self.tracker_host} --port {self.tracker_port}" + ) + return tracker + + def connect_server(self) -> rpc.RPCSession: + """Connect to the server + + Returns + ------- + session : RPCSession + The connected rpc session + """ + tracker = self.connect_tracker() + session: rpc.RPCSession = tracker.request( + key=self.tracker_key, + priority=self.session_priority, + session_timeout=self.session_timeout_sec, + ) + return session + + def count_num_servers(self, allow_missing=True) -> int: + """Count the number of servers available in the tracker + + Parameters + ---------- + allow_missing : bool + Whether to allow no server to be found. + + Returns + ------- + num_servers : int + The number of servers + """ + tracker = self.connect_tracker() + tracker_summary = tracker.summary() + result: int = 0 + for item in tracker_summary["server_info"]: + _, item_key = item["key"].split(":") + if item_key == self.tracker_key: + result += 1 + if result == 0 and not allow_missing: + raise ValueError( + "Unable to find servers with the specific key using the following configuration:\n" + f" tracker host: {self.tracker_host}\n" + f" tracker port: {self.tracker_port}\n" + f" tracker key: {self.tracker_key}\n" + f" timeout (sec): {self.session_timeout_sec}\n" + "Please check the tracker status via the following command:\n" + " python3 -m tvm.exec.query_rpc_tracker " + f"--host {self.tracker_host} --port {self.tracker_port}\n" + f'and look for key: "{self.tracker_key}"' + ) + return result diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py new file mode 100644 index 000000000000..d20e1707fcec --- /dev/null +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -0,0 +1,567 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""RPC Runner""" +import concurrent.futures +from contextlib import contextmanager +import itertools +import os.path as osp +from typing import Any, Callable, Dict, List, Optional, Union + +from tvm.contrib.popen_pool import PopenPoolExecutor +from tvm.rpc import RPCSession +from tvm.runtime import Device, Module, ndarray + +from ..utils import ( + get_global_func_on_rpc_session, + get_global_func_with_default_on_worker, +) +from .config import EvaluatorConfig, RPCConfig +from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult + + +class RPCRunnerFuture(RunnerFuture): + """RPC based runner future + + Parameters + ---------- + future: concurrent.futures.Future + The concurrent function to check when the function is done and to return the result. + timeout_sec: float + The timeout in seconds. + """ + + future: concurrent.futures.Future + timeout_sec: float + + def __init__(self, future: concurrent.futures.Future, timeout_sec: float) -> None: + """Constructor + + Parameters + ---------- + future: concurrent.futures.Future + The concurrent function to check when the function is done and to return the result. + timeout_sec: float + The timeout in seconds. + """ + super().__init__() + self.future = future + self.timeout_sec = timeout_sec + + def done(self) -> bool: + return self.future.done() + + def result(self) -> RunnerResult: + try: + run_secs: List[float] = self.future.result() + except TimeoutError as exception: + return RunnerResult( + None, + error_msg=f"RPCRunner: Timeout, killed after {self.timeout_sec} seconds", + ) + except Exception as exception: # pylint: disable=broad-except + return RunnerResult( + None, + error_msg="RPCRunner: An exception occurred\n" + str(exception), + ) + return RunnerResult(run_secs, None) + + +T_ARG_INFO_JSON_OBJ = List[Any] # pylint: disable=invalid-name +T_ARG_INFO_JSON_OBJ_LIST = List[T_ARG_INFO_JSON_OBJ] # pylint: disable=invalid-name +T_ARGUMENT = Any # pylint: disable=invalid-name +T_ARGUMENT_LIST = List[T_ARGUMENT] # pylint: disable=invalid-name + + +class RPCRunner(PyRunner): + """RPC based runner + + Parameters + ---------- + rpc_config: RPCConfig + The rpc configuration. + evaluator_config: EvaluatorConfig + The evaluator configuration. + cooldown_sec: float + The cooldown in seconds. TODO(@junrushao1994,@zxybazh): This is not used yet. + alloc_repeat: int + The number of times to repeat the allocation. + f_create_session: Optional[str, Callable] + The function name to create the session or the function itself. + f_upload_module: Optional[str, Callable] + The function name to upload the module or the function itself. + f_alloc_argument: Optional[str, Callable] + The function name to allocate the arguments or the function itself. + f_run_evaluator: Optional[str, Callable] + The function name to run the evaluator or the function itself. + f_cleanup: Optional[str, Callable] + The function name to cleanup the session or the function itself. + pool: PopenPoolExecutor + The popen pool executor. + + Attributes + ---------- + T_CREATE_SESSION : typing._GenericAlias + The signature of the function `f_create_session`, which is: + + .. code-block:: python + + def default_create_session(rpc_config: RPCConfig) -> RPCSession: + ... + + T_UPLOAD_MODULE : typing._GenericAlias + The signature of the function `f_upload_module`, which is: + + .. code-block:: python + + def default_upload_module( + session: RPCSession, + local_path: str, + remote_path: str, + ) -> Module: + ... + + T_ALLOC_ARGUMENT : typing._GenericAlias + The signature of the function `f_alloc_argument`, which is: + + .. code-block:: python + + def default_alloc_argument( + session: RPCSession, + device: Device, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + alloc_repeat: int, + ) -> List[T_ARGUMENT_LIST]: + ... + + T_RUN_EVALUATOR : typing._GenericAlias + The signature of the function `f_run_evaluator`, which is: + + .. code-block:: python + + def default_run_evaluator( + session: RPCSession, + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[T_ARGUMENT_LIST], + ) -> List[float]: + ... + + T_CLEANUP : typing._GenericAlias + The signature of the function `f_cleanup`, which is: + + .. code-block:: python + + def default_cleanup( + session: Optional[RPCSession], + remote_path: Optional[str], + ) -> None: + ... + """ + + T_CREATE_SESSION = Callable[ + [RPCConfig], # The RPC configuration + RPCSession, # The RPC Session + ] + T_UPLOAD_MODULE = Callable[ + [ + RPCSession, # The RPC Session + str, # local path to the artifact + str, # remote path to the artifact + ], + Module, # the Module opened on the remote + ] + T_ALLOC_ARGUMENT = Callable[ + [ + RPCSession, # The RPC Session + Device, # The device on the remote + T_ARG_INFO_JSON_OBJ_LIST, # The metadata information of the arguments to be allocated + int, # The number of repeated allocations to be done + ], + List[T_ARGUMENT_LIST], # A list of argument lists + ] + T_RUN_EVALUATOR = Callable[ + [ + RPCSession, # The RPC Session + Module, # The Module opened on the remote + Device, # The device on the remote + EvaluatorConfig, # The evaluator configuration + List[T_ARGUMENT_LIST], # A list of argument lists + ], + List[float], # A list of running time + ] + T_CLEANUP = Callable[ + [ + Optional[RPCSession], # The RPC Session to be cleaned up + Optional[str], # remote path to the artifact + ], + None, + ] + + rpc_config: RPCConfig + evaluator_config: EvaluatorConfig + cooldown_sec: float + alloc_repeat: int + + f_create_session: Union[T_CREATE_SESSION, str, None] + f_upload_module: Union[T_UPLOAD_MODULE, str, None] + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] + f_cleanup: Union[T_CLEANUP, str, None] + + pool: PopenPoolExecutor + + def __init__( + self, + rpc_config: Optional[RPCConfig] = None, + evaluator_config: Optional[EvaluatorConfig] = None, + cooldown_sec: float = 0.0, + alloc_repeat: int = 1, + f_create_session: Union[T_CREATE_SESSION, str, None] = None, + f_upload_module: Union[T_UPLOAD_MODULE, str, None] = None, + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] = None, + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] = None, + f_cleanup: Union[T_CLEANUP, str, None] = None, + max_connections: Optional[int] = None, + initializer: Optional[Callable[[], None]] = None, + ) -> None: + """Constructor + + Parameters + ---------- + rpc_config: RPCConfig + The rpc configuration. + evaluator_config: EvaluatorConfig + The evaluator configuration. + cooldown_sec: float + The cooldown in seconds. + alloc_repeat: int + The number of times to random fill the allocation. + f_create_session: Union[T_CREATE_SESSION, str, None] + The function name to create the session or the function itself. + f_upload_module: Union[T_UPLOAD_MODULE, str, None] + The function name to upload the module or the function itself. + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] + The function name to allocate the arguments or the function itself. + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] + The function name to run the evaluator or the function itself. + f_cleanup: Union[T_CLEANUP, str, None] + The function name to cleanup the session or the function itself. + max_connections: Optional[int] + The maximum number of connections. + initializer: Optional[Callable[[], None]] + The initializer function. + """ + super().__init__() + self.rpc_config = RPCConfig._normalized(rpc_config) + self.evaluator_config = EvaluatorConfig._normalized(evaluator_config) + self.cooldown_sec = cooldown_sec + self.alloc_repeat = alloc_repeat + self.f_create_session = f_create_session + self.f_upload_module = f_upload_module + self.f_alloc_argument = f_alloc_argument + self.f_run_evaluator = f_run_evaluator + self.f_cleanup = f_cleanup + + num_servers = self.rpc_config.count_num_servers(allow_missing=False) + if max_connections is None: + max_connections = num_servers + else: + max_connections = min(max_connections, num_servers) + + self.pool = PopenPoolExecutor( + max_workers=max_connections, + timeout=rpc_config.session_timeout_sec, + initializer=initializer, + ) + self._sanity_check() + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + results: List[RunnerFuture] = [] + for runner_input in runner_inputs: + future = RPCRunnerFuture( + future=self.pool.submit( + RPCRunner._worker_func, + self.f_create_session, + self.f_upload_module, + self.f_alloc_argument, + self.f_run_evaluator, + self.f_cleanup, + self.rpc_config, + self.evaluator_config, + self.alloc_repeat, + str(runner_input.artifact_path), + str(runner_input.device_type), + tuple(arg_info.as_json() for arg_info in runner_input.args_info), + ), + timeout_sec=self.rpc_config.session_timeout_sec, + ) + results.append(future) + return results + + def _sanity_check(self) -> None: + def _check( + f_create_session, + f_upload_module, + f_alloc_argument, + f_run_evaluator, + f_cleanup, + ) -> None: + get_global_func_with_default_on_worker(name=f_create_session, default=None) + get_global_func_with_default_on_worker(name=f_upload_module, default=None) + get_global_func_with_default_on_worker(name=f_alloc_argument, default=None) + get_global_func_with_default_on_worker(name=f_run_evaluator, default=None) + get_global_func_with_default_on_worker(name=f_cleanup, default=None) + + value = self.pool.submit( + _check, + self.f_create_session, + self.f_upload_module, + self.f_alloc_argument, + self.f_run_evaluator, + self.f_cleanup, + ) + value.result() + + @staticmethod + def _worker_func( + _f_create_session: Union[T_CREATE_SESSION, str, None], + _f_upload_module: Union[T_UPLOAD_MODULE, str, None], + _f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None], + _f_run_evaluator: Union[T_RUN_EVALUATOR, str, None], + _f_cleanup: Union[T_CLEANUP, str, None], + rpc_config: RPCConfig, + evaluator_config: EvaluatorConfig, + alloc_repeat: int, + artifact_path: str, + device_type: str, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + ) -> List[float]: + # Step 0. Get the registered functions + f_create_session: RPCRunner.T_CREATE_SESSION = get_global_func_with_default_on_worker( + _f_create_session, default_create_session + ) + f_upload_module: RPCRunner.T_UPLOAD_MODULE = get_global_func_with_default_on_worker( + _f_upload_module, default_upload_module + ) + f_alloc_argument: RPCRunner.T_ALLOC_ARGUMENT = get_global_func_with_default_on_worker( + _f_alloc_argument, default_alloc_argument + ) + f_run_evaluator: RPCRunner.T_RUN_EVALUATOR = get_global_func_with_default_on_worker( + _f_run_evaluator, default_run_evaluator + ) + f_cleanup: RPCRunner.T_CLEANUP = get_global_func_with_default_on_worker( + _f_cleanup, default_cleanup + ) + # Managed resources + session: Optional[RPCSession] = None + remote_path: Optional[str] = None + + @contextmanager + def resource_handler(): + try: + yield + finally: + # Step 5. Clean up + f_cleanup(session, remote_path) + + with resource_handler(): + # Step 1. Create session + session = f_create_session(rpc_config) + device = session.device(dev_type=device_type, dev_id=0) + # Step 2. Upload the module + _, remote_path = osp.split(artifact_path) + local_path: str = artifact_path + rt_mod: Module = f_upload_module(session, local_path, remote_path) + # Step 3: Allocate input arguments + repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( + session, + device, + args_info, + alloc_repeat, + ) + # Step 4: Run time_evaluator + costs: List[float] = f_run_evaluator( + session, + rt_mod, + device, + evaluator_config, + repeated_args, + ) + return costs + + +def default_create_session(rpc_config: RPCConfig) -> RPCSession: + """Default function to create the session + + Parameters + ---------- + rpc_config : RPCConfig + The configuration of the RPC session + + Returns + ------- + session : RPCSession + The created rpc session + """ + return rpc_config.connect_server() + + +def default_upload_module( + session: RPCSession, + local_path: str, + remote_path: str, +) -> Module: + """Default function to upload the module + + Parameters + ---------- + session: RPCSession + The session to upload the module + local_path: str + The local path of the module + remote_path: str + The remote path to place the module + + Returns + ------- + rt_mod : Module + The runtime module + """ + session.upload(local_path, remote_path) + rt_mod: Module = session.load_module(remote_path) + return rt_mod + + +def default_alloc_argument( + session: RPCSession, + device: Device, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + alloc_repeat: int, +) -> List[T_ARGUMENT_LIST]: + """Default function to allocate the arguments + + Parameters + ---------- + session: RPCSession + The session to allocate the arguments + device: Device + The device to allocate the arguments + alloc_repeat: int + The number of times to repeat the allocation + args_info: PyArgsInfo + The arguments info + + Returns + ------- + repeated_args: List[Args] + The allocation args + """ + f_random_fill = get_global_func_on_rpc_session( + session, + "tvm.contrib.random.random_fill", + "Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.", + ) + + def alloc_tensor(_, dtype, shape) -> ndarray.NDArray: + arg = ndarray.empty(shape=shape, dtype=dtype, device=device) + f_random_fill(arg) + return arg + + def alloc_fail(*arg_info) -> None: + raise NotImplementedError(arg_info) + + dispatcher: Dict[Any, Callable] = { + "TENSOR": alloc_tensor, + None: alloc_fail, + } + + repeated_args: List[T_ARGUMENT_LIST] = [] + for _ in range(alloc_repeat): + args: T_ARGUMENT_LIST = [] + arg_info: T_ARG_INFO_JSON_OBJ + for arg_info in args_info: + arg_type = arg_info[0] + arg: Any = dispatcher.get(arg_type, None)(*arg_info) + args.append(arg) + repeated_args.append(args) + return repeated_args + + +def default_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[T_ARGUMENT_LIST], +) -> List[float]: + """Default function to run the evaluator + + Parameters + ---------- + session: RPCSession + The session to run the evaluator + rt_mod: Module + The runtime module + device: Device + The device to run the evaluator + evaluator_config: EvaluatorConfig + The evaluator config + repeated_args: List[Args] + The repeated arguments + + Returns + ------- + costs: List[float] + The evaluator results + """ + evaluator = rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs + + +def default_cleanup( + session: Optional[RPCSession], + remote_path: Optional[str], +) -> None: + """Default function to clean up the session + + Parameters + ---------- + session: RPCSession + The session to clean up + remote_path: str + The remote path to clean up + """ + if session is not None and remote_path is not None: + session.remove(remote_path) + session.remove(remote_path + ".so") + session.remove("") diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index b756c6e6b011..9f7be8ea4af4 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -21,6 +21,50 @@ from tvm.runtime import Object from .. import _ffi_api +from ..arg_info import ArgInfo + + +@register_object("meta_schedule.RunnerInput") +class RunnerInput(Object): + """The runner's input + + Parameters + ---------- + artifact_path : str + The path to the built artifact. + device_type : str + The device type. + args_info : List[ArgInfo] + The argument information. + """ + + artifact_path: str + device_type: str + args_info: List[ArgInfo] + + def __init__( + self, + artifact_path: str, + device_type: str, + args_info: List[ArgInfo], + ) -> None: + """Constructor + + Parameters + ---------- + artifact_path : str + The path to the built artifact. + device_type : str + The device type. + args_info : List[ArgInfo] + The argument information. + """ + self.__init_handle_by_constructor__( + _ffi_api.RunnerInput, # type: ignore # pylint: disable=no-member + artifact_path, + device_type, + args_info, + ) @register_object("meta_schedule.RunnerResult") @@ -57,3 +101,70 @@ def __init__( run_secs, error_msg, ) + + +@register_object("meta_schedule.RunnerFuture") +class RunnerFuture(Object): + """A class to fetch asynchronous runner's output.""" + + def __init__(self) -> None: + """Constructor""" + + def f_done(): + return self.done() + + def f_result(): + return self.result() + + self.__init_handle_by_constructor__( + _ffi_api.RunnerFuture, # type: ignore # pylint: disable=no-member + f_done, + f_result, + ) + + def done(self) -> bool: + """Check whether the runner has finished.""" + raise NotImplementedError + + def result(self) -> RunnerResult: + """Fetch the runner's output if it is ready.""" + raise NotImplementedError + + +@register_object("meta_schedule.Runner") +class Runner(Object): + """The abstract runner interface""" + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + """Run the built artifact and get runner futures. + + Parameters + ---------- + runner_inputs : List[RunnerInput] + The inputs to the runner. + + Returns + ------- + runner_futures: List[RunnerFuture] + The runner futures. + """ + return _ffi_api.RunnerRun(self, runner_inputs) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyRunner") +class PyRunner(Runner): + """An abstract runner with customized build method on the python-side.""" + + def __init__(self) -> None: + """Constructor""" + + def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + return self.run(runner_inputs) + + self.__init_handle_by_constructor__( + _ffi_api.RunnerPyRunner, # type: ignore # pylint: disable=no-member + f_run, + ) + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/testing.py b/python/tvm/meta_schedule/testing.py new file mode 100644 index 000000000000..4caaeb7553cc --- /dev/null +++ b/python/tvm/meta_schedule/testing.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities in meta schedule""" +import time + +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server + + +class LocalRPC: + """A pair of RPC tracker/server running locally + + Parameters + ---------- + tracker_host : str + The host URL of the tracker + tracker_port : int + The port of the tracker + tracker_key: str + The key used in the tracker to refer to a worker + """ + + tracker_host: str + tracker_port: int + tracker_key: str + + def __init__( + self, + tracker_key: str = "key", + silent: bool = False, + no_fork: bool = False, + ) -> None: + self.tracker = Tracker( + silent=silent, + port=9190, + port_end=12345, + ) + time.sleep(0.5) + self.server = Server( + host="0.0.0.0", + is_proxy=False, + tracker_addr=(self.tracker.host, self.tracker.port), + key=tracker_key, + silent=silent, + no_fork=no_fork, + port=9190, + port_end=12345, + ) + self.tracker_host = self.tracker.host + self.tracker_port = self.tracker.port + self.tracker_key = tracker_key + + def __enter__(self): + return self + + def __exit__(self, _type, _value, _traceback): + if hasattr(self, "server"): + del self.server + if hasattr(self, "tracker"): + del self.tracker diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 4c83b9afa289..9c41b4d575da 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -19,10 +19,10 @@ from typing import Optional, TYPE_CHECKING from tvm import IRModule +from tvm._ffi import register_object +from tvm.meta_schedule.utils import cpu_count from tvm.runtime import Object from tvm.target import Target -from tvm.meta_schedule.utils import cpu_count -from tvm._ffi import register_object from . import _ffi_api diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index e710b0ed06f3..5f536994a9fd 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -18,14 +18,14 @@ import json import os import shutil -from typing import Any, Callable, List, Union +from typing import Any, Callable, List, Optional, Union import psutil - from tvm._ffi import get_global_func, register_func from tvm.error import TVMError from tvm.ir import Array, Map -from tvm.runtime import String +from tvm.rpc import RPCSession +from tvm.runtime import PackedFunc, String from tvm.tir import FloatImm, IntImm @@ -95,6 +95,37 @@ def get_global_func_with_default_on_worker( ) from error +def get_global_func_on_rpc_session( + session: RPCSession, + name: str, + extra_error_msg: Optional[str] = None, +) -> PackedFunc: + """Get a PackedFunc from the global registry from an RPCSession. + + Parameters + ---------- + session : RPCSession + The RPCSession to be retrieved from + name : str + The name of the PackedFunc + extra_error_msg : Optional[str] + Extra information to provide in the error message + + Returns + ------- + result : PackedFunc + The result + """ + try: + result = session.get_function(name) + except AttributeError as error: + error_msg = f'Unable to find function "{name}" on the remote RPC server.' + if extra_error_msg: + error_msg = f"{error_msg} {extra_error_msg}" + raise AttributeError(error_msg) from error + return result + + @register_func("meta_schedule.remove_build_dir") def remove_build_dir(artifact_path: str) -> None: """Clean up the build directory""" diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc index 8f509bdd7b84..800a76f21e65 100644 --- a/src/meta_schedule/runner/runner.cc +++ b/src/meta_schedule/runner/runner.cc @@ -16,13 +16,19 @@ * specific language governing permissions and limitations * under the License. */ -#include - #include "../utils.h" namespace tvm { namespace meta_schedule { +RunnerInput::RunnerInput(String artifact_path, String device_type, Array args_info) { + ObjectPtr n = make_object(); + n->artifact_path = artifact_path; + n->device_type = device_type; + n->args_info = args_info; + this->data_ = n; +} + RunnerResult::RunnerResult(Optional> run_secs, Optional error_msg) { ObjectPtr n = make_object(); n->run_secs = run_secs; @@ -30,12 +36,45 @@ RunnerResult::RunnerResult(Optional> run_secs, Optional this->data_ = n; } -TVM_REGISTER_NODE_TYPE(RunnerResultNode); +RunnerFuture::RunnerFuture(RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) { + ObjectPtr n = make_object(); + n->f_done = f_done; + n->f_result = f_result; + this->data_ = n; +} +Runner Runner::PyRunner(Runner::FRun f_run) { + ObjectPtr n = make_object(); + n->f_run = f_run; + return Runner(n); +} + +/******** FFI ********/ + +TVM_REGISTER_NODE_TYPE(RunnerInputNode); +TVM_REGISTER_NODE_TYPE(RunnerResultNode); +TVM_REGISTER_NODE_TYPE(RunnerFutureNode); +TVM_REGISTER_OBJECT_TYPE(RunnerNode); +TVM_REGISTER_NODE_TYPE(PyRunnerNode); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerInput") + .set_body_typed([](String artifact_path, String device_type, + Array args_info) -> RunnerInput { + return RunnerInput(artifact_path, device_type, args_info); + }); TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult") .set_body_typed([](Array run_secs, Optional error_msg) -> RunnerResult { return RunnerResult(run_secs, error_msg); }); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFuture") + .set_body_typed([](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { + return RunnerFuture(f_done, f_result); + }); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone") + .set_body_method(&RunnerFutureNode::Done); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult") + .set_body_method(&RunnerFutureNode::Result); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method(&RunnerNode::Run); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner); } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py new file mode 100644 index 000000000000..3c8aee0c6d58 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -0,0 +1,571 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Runner """ + +import itertools +import sys +import time +from typing import Any, List + +import numpy as np +import pytest + +import tvm +from tvm import tir +from tvm._ffi import register_func +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + PyRunner, + RPCConfig, + RPCRunner, + RunnerFuture, + RunnerInput, +) +from tvm.meta_schedule.runner.rpc_runner import ( + default_alloc_argument as rpc_default_alloc_argument, +) +from tvm.meta_schedule.testing import LocalRPC +from tvm.meta_schedule.utils import get_global_func_with_default_on_worker +from tvm.rpc import RPCSession +from tvm.runtime import Device, Module +from tvm.script import ty +from tvm.target import Target +import tvm.testing +from tvm.tir import FloatImm + +MATMUL_N = 16 +MATMUL_M = 32 + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking + + +@tvm.script.tir +class MatmulModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.tir +class MatmulReluModule: + def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + D = tir.match_buffer(d, (16, 16), "float32") + C = tir.alloc_buffer((16, 16), "float32") + with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + with tir.block([16, 16], "relu") as [vi, vj]: + D[vi, vj] = tir.max(C[vi, vj], 0.0) + + +@tvm.script.tir +class BatchMatmulModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [16, 32, 32]) + B = tir.match_buffer(b, [16, 32, 32]) + C = tir.match_buffer(c, [16, 32, 32]) + with tir.block([16, 32, 32, tir.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]: + with tir.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@tvm.script.tir +class AddModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [32], "float32") + B = tir.match_buffer(b, [32], "float32") + C = tir.match_buffer(c, [32], "float32") + with tir.block([32], "add") as [vi]: + C[vi] = A[vi] + B[vi] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +def _clean_build(artifact_path: str) -> None: + f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None) + if f_clean_build is not None: + f_clean_build(artifact_path) + else: + raise RuntimeError("Unable to find remove_build_dir function.") + + +def test_meta_schedule_rpc_single_run(): + """Test meta schedule rpc runner for a single run""" + # Build the module + mod = MatmulModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner(rpc_config, evaluator_config) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_rpc_multiple_runs(): + """Test meta schedule rpc runner for multiple runs""" + # Build the module + mods = [ + MatmulModule(), + MatmulReluModule(), + BatchMatmulModule(), + ] + builder = LocalBuilder() + builder_inputs = [BuilderInput(mod, Target("llvm")) for mod in mods] + builder_results = builder.build(builder_inputs) + for builder_result in builder_results: + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + args_infos = [ + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + [ + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + ], + ] + + runner_inputs = [ + RunnerInput(builder_results[i].artifact_path, "llvm", args_infos[i]) + for i in range(len(mods)) + ] + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner(rpc_config, evaluator_config) + # Run the module + runner_futures = runner.run(runner_inputs) + runner_results = [runner_future.result() for runner_future in runner_futures] + + for runner_result in runner_results: + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + for builder_result in builder_results: + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_py_runner(): + """Test meta schedule PyRunner""" + + class TestRunner(PyRunner): + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + raise ValueError("TestRunner") + + runner = TestRunner() + with pytest.raises(ValueError, match="TestRunner"): + runner.run([]) + + +def test_meta_schedule_rpc_runner_time_out(): + """Test meta schedule RPC Runner time out""" + + def initializer(): + @register_func("meta_schedule.runner.test_time_out") + def timeout_session_creator( # pylint: disable=unused-variable + rpc_config: RPCConfig, # pylint: disable=unused-argument + ) -> RPCSession: + time.sleep(2) + + runner_input = RunnerInput( + "test", + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=1, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + initializer=initializer, + f_create_session="meta_schedule.runner.test_time_out", + ) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + assert runner_result.error_msg is not None and runner_result.error_msg.startswith( + "RPCRunner: Timeout, killed after" + ) + assert runner_result.run_secs is None + + +def test_meta_schedule_rpc_runner_exception(): + """Test meta schedule RPC Runner exception""" + + def initializer(): + @register_func("meta_schedule.runner.test_exception") + def exception_session_creator( # pylint: disable=unused-variable + rpc_config: RPCConfig, # pylint: disable=unused-argument + ) -> RPCSession: + raise Exception("Test") + + runner_input = RunnerInput( + "test", + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + initializer=initializer, + f_create_session="meta_schedule.runner.test_exception", + ) + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + assert runner_result.error_msg is not None and runner_result.error_msg.startswith( + "RPCRunner: An exception occurred\n" + ) + assert runner_result.run_secs is None + + +def test_meta_schedule_runner_matmul_test(): + """Test meta schedule runner with add module""" + + def _check_correct_matmul( + args_before: List[np.ndarray], + args_after: List[np.ndarray], + ) -> None: + a_before, b_before, c_before = args_before + a_after, b_after, c_after = args_after + c_before = np.matmul(a_before, b_before) + assert (a_before == a_after).all() + assert (b_before == b_after).all() + tvm.testing.assert_allclose(c_before, c_after, rtol=1e-5) + + def test_alloc_argument( + session: RPCSession, + device: Device, + args_info: Any, + alloc_repeat: int, + ) -> List[Any]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_before = [] # type: ignore + repeated_args = rpc_default_alloc_argument(session, device, args_info, alloc_repeat) + for args in repeated_args: + repeated_args_before.append([arg.numpy() for arg in args]) # type: ignore + return repeated_args + + def test_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[Any], + ) -> List[float]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_after = [] + evaluator = rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + repeated_args_after.append([arg.numpy() for arg in args]) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + for args_before, args_after in zip( + repeated_args_before, # type: ignore + repeated_args_after, + ): + _check_correct_matmul(args_before, args_after) + del repeated_args_before # type: ignore + return costs + + # Build the module + mod = MatmulModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + f_alloc_argument=test_alloc_argument, + f_run_evaluator=test_run_evaluator, + ) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_runner_add_test(): + """Test meta schedule runner with add module""" + + def _check_correct_add(args_before: List[np.ndarray], args_after: List[np.ndarray]) -> None: + a_before, b_before, c_before = args_before + a_after, b_after, c_after = args_after + c_before = a_before + b_before + assert (a_before == a_after).all() + assert (b_before == b_after).all() + assert (c_before == c_after).all() + + def test_alloc_argument( + session: RPCSession, + device: Device, + args_info: Any, + alloc_repeat: int, + ) -> List[Any]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_before = [] # type: ignore + repeated_args = rpc_default_alloc_argument( + session, + device, + args_info, + alloc_repeat, + ) + for args in repeated_args: + repeated_args_before.append([arg.numpy() for arg in args]) # type: ignore + return repeated_args + + def test_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[Any], + ) -> List[float]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_after = [] + evaluator = rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + repeated_args_after.append([arg.numpy() for arg in args]) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + for args_before, args_after in zip( + repeated_args_before, # type: ignore + repeated_args_after, + ): + _check_correct_add(args_before, args_after) + del repeated_args_before # type: ignore + return costs + + # Build the module + mod = AddModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", [MATMUL_M]), + TensorInfo("float32", [MATMUL_M]), + TensorInfo("float32", [MATMUL_M]), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + f_alloc_argument=test_alloc_argument, + f_run_evaluator=test_run_evaluator, + ) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 7a08ae4dcad068c003cba4f2299def638c98450a Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Wed, 29 Sep 2021 17:36:06 +0100 Subject: [PATCH 16/20] [CI] Split Integration tests out of first phase of pipeline (#9128) * [CI] Split Integration tests out of first phase of pipeline I took a look at the time taken by each stage in the Jenkins pipeline and what comprises the 6 hour CI build time. CPU Integration tests took `65` minutes of the `100` minutes of `Build: CPU`. By adding `python3: CPU` with just those Integration tests, it lines up with `python3: GPU` and `python3: i386` which both take a similar amount of time and takes roughly 60 minutes off the overall run time. Numbers copied from sample successful run (final time approx: 358 minutes): |Phase|ID |Job |Minutes |Start| |-----|-----------------------------|------|---------------------------------------------|-----| |0 |0 |Sanity|3 |0 | |1 |0 |BUILD: arm|2 |3 | |1 |1 |BUILD: i386|33 |3 | |1 |2 |BUILD: CPU|100 |3 | |1 |3 |BUILD: GPU|25 |3 | |1 |4 |BUILD: QEMU|6 |3 | |1 |5 |BUILD: WASM|2 |3 | |2 |0 |java: GPU|1 |103 | |2 |1 |python3: GPU|66 |103 | |2 |2 |python3: arm|22 |103 | |2 |3 |python3: i386|70 |103 | |3 |0 |docs: GPU|3 |173 | |3 |1 |frontend: CPU|40 |173 | |3 |2 |frontend: GPU|185 |173 | |3 |3 |topi: GPU|110 |173 | | | | | | | Numbers predicted after change (final time approx: 293 minutes): |Phase|ID |Job |Minutes |Start| |-----|-----------------------------|------|---------------------------------------------|-----| |0 |0 |Sanity|3 |0 | |1 |0 |BUILD: arm|2 |3 | |1 |1 |BUILD: i386|33 |3 | |1 |2 |BUILD: CPU|35 |3 | |1 |3 |BUILD: GPU|25 |3 | |1 |4 |BUILD: QEMU|6 |3 | |1 |5 |BUILD: WASM|2 |3 | |2 |0 |java: GPU|1 |38 | |2 |1 |python3: GPU|66 |38 | |2 |2 |python3: arm|22 |38 | |2 |3 |python3: i386|70 |38 | |2 |4 |python3: CPU|60 |38 | |3 |0 |docs: GPU|3 |108 | |3 |1 |frontend: CPU|40 |108 | |3 |2 |frontend: GPU|185 |108 | |3 |3 |topi: GPU|110 |108 | * Fix typo in ci_cpu commands --- Jenkinsfile | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index b2852955323f..3a96fbee061d 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -224,7 +224,6 @@ stage('Build') { timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_cpu} ./tests/scripts/task_ci_setup.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_unittest.sh" - sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_fsim.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh" // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" @@ -300,6 +299,19 @@ stage('Unit Test') { } } }, + 'python3: CPU': { + node('CPU') { + ws(per_exec_ws("tvm/ut-python-cpu")) { + init_git() + unpack_lib('cpu', tvm_multilib_tsim) + timeout(time: max_time, unit: 'MINUTES') { + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_ci_setup.sh" + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh" + junit "build/pytest-results/*.xml" + } + } + } + }, 'python3: i386': { node('CPU') { ws(per_exec_ws("tvm/ut-python-i386")) { From 86ce111edb86858a12b55478eed9625461f0fa05 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 29 Sep 2021 17:38:25 +0100 Subject: [PATCH 17/20] Arm(R) Ethos(TM)-U NPU codegen integration (#8849) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit integrates the codegen for Arm® Ethos™-U. * Adding Conv2D tests and a mobilenet_v1 conv2d offload test. Co-authored-by: Grant Watson Co-authored-by: Leandro Nunes Co-authored-by: Christopher Sidebottom Co-authored-by: Matthew Barret Co-authored-by: Grant Watson Co-authored-by: Leandro Nunes Co-authored-by: Christopher Sidebottom Co-authored-by: Matthew Barret --- .../relay/backend/contrib/ethosu/__init__.py | 1 + .../relay/backend/contrib/ethosu/codegen.py | 83 +++++ .../relay/backend/contrib/ethosu/legalize.py | 6 + .../tvm/relay/backend/contrib/ethosu/util.py | 12 + src/relay/backend/aot_executor_codegen.cc | 1 - .../backend/contrib/ethosu/source_module.cc | 320 ++++++++++++++++++ tests/python/contrib/test_ethosu/__init__.py | 17 + tests/python/contrib/test_ethosu/infra.py | 228 +++++++++++++ .../reference_system/arm-none-eabi-gcc.cmake | 79 +++++ .../test_ethosu/reference_system/ethosu_55.h | 27 ++ .../test_ethosu/reference_system/ethosu_mod.h | 59 ++++ .../test_ethosu/reference_system/hard_fault.h | 53 +++ .../contrib/test_ethosu/test_codegen.py | 174 ++++++++++ .../test_ethosu/test_encode_constants.py | 2 +- .../contrib/test_ethosu/test_legalize.py | 3 +- .../contrib/test_ethosu/test_networks.py | 65 ++++ .../test_ethosu/test_replace_conv2d.py | 2 +- .../contrib/test_ethosu/test_replace_copy.py | 2 +- .../contrib/test_ethosu/test_scheduler.py | 2 +- tests/python/relay/aot/aot_test_utils.py | 201 +++++++---- tests/python/relay/aot/corstone300.ld | 8 + tests/python/relay/aot/corstone300.mk | 23 +- tests/python/relay/aot/test_crt_aot.py | 21 +- 23 files changed, 1315 insertions(+), 74 deletions(-) create mode 100644 python/tvm/relay/backend/contrib/ethosu/codegen.py create mode 100644 src/relay/backend/contrib/ethosu/source_module.cc create mode 100644 tests/python/contrib/test_ethosu/__init__.py create mode 100644 tests/python/contrib/test_ethosu/reference_system/arm-none-eabi-gcc.cmake create mode 100644 tests/python/contrib/test_ethosu/reference_system/ethosu_55.h create mode 100644 tests/python/contrib/test_ethosu/reference_system/ethosu_mod.h create mode 100644 tests/python/contrib/test_ethosu/reference_system/hard_fault.h create mode 100644 tests/python/contrib/test_ethosu/test_codegen.py create mode 100644 tests/python/contrib/test_ethosu/test_networks.py diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py b/python/tvm/relay/backend/contrib/ethosu/__init__.py index 2b424ebb5dec..5fd1a0c19dc9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py @@ -19,6 +19,7 @@ from . import legalize from . import preprocess from . import errors +from . import codegen from . import vela_api from . import tir_to_cs_translator from .util import partition_for_ethosu diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py new file mode 100644 index 000000000000..e821ea8bf0c4 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Codegen for Arm(R) Ethos(TM)-U""" +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants +from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator +from tvm.relay.backend.contrib.ethosu import util + + +@tvm._ffi.register_func("relay.ext.ethosu.constant_updater") +def constant_updater(expr, symbol): # pylint: disable=unused-argument + """ + We dont want the build process to extract constants to be loaded in + the runtime as we are embedding them inside the C runtime.Module. + """ + return dict() + + +@tvm._ffi.register_func("relay.ext.ethosu") +def ethosu_compiler(ref): + """Main function to a compile a given relay function of + NPU compatible operators to generated command stream. + Such generated command stream would be loaded to the runtime + module that interfaces with NPU driver. + """ + assert isinstance(ref, tvm.ir.function.BaseFunc) + func_name = ref.attrs["global_symbol"] + # There should only be a single input + assert len(ref.params) == 1 + input_size = util.calculate_size_bytes(ref.params[0]) + output_size = util.calculate_size_bytes(ref.body) + cmms, encoded_constants, scratch_size = _compile(ref) + ethosu_runtime = tvm._ffi.get_global_func("runtime.module.ethosu.create") + return ethosu_runtime(func_name, cmms, encoded_constants, scratch_size, input_size, output_size) + + +def _compile(ext_func): + """ + This is the main wrapper that accepts an external + relay function and runs all the passes to lower it down + to command stream + Parameters + ---------- + ext_func : tvm.relay.function.Function + The partitioned relay function + Returns + ------- + cs : str + An hex string of the bytes of command stream + encoded_constants : str + An hex string of the bytes that includes concat'd + encoded weights, encoded biases and scales. + scratch_size : int + The size of the scratch buffer needed. + """ + mod = tvm.IRModule() + mod["main"] = ext_func + mod = LegalizeEthosU()(mod) + mod = relay.transform.InferType()(mod) + # We are currently using copy_constants scheduler In the long run, + # this should be a single intelligent and a composite scheduler + # that can perform scheduling based on user inputs such as + # scratch memory size. + tir_mod, params = lower_to_tir(mod["main"], copy_constants()) + cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(tir_mod, params) + return cmms, encoded_constants, scratch_size diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 82b7f1e68cee..fd58da803623 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -221,3 +221,9 @@ def transform_module( mod = LegalizeSplit()(mod) mod = LegalizeEthosUConv2D()(mod) return mod + + def __call__(self, *args, **kwargs): + # pylint is unable figure out the decorated + # class is callable, thus adding this to + # suppress the warning. + pass diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 0919d3fe7a5f..b5c2179b893b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -197,3 +197,15 @@ def get_dim_value(layout: str, dim: int): if dim_char == dim: return idx return None + + +def calculate_size_bytes(expr): + """This is a helper function to calculate the number + of bytes required to hold the tensor/relay.expr""" + try: + type_info = np.iinfo(expr.checked_type.dtype) + except ValueError: + type_info = np.finfo(expr.checked_type.dtype) + element_size = type_info.bits // 8 + elements = np.prod(list(expr.checked_type.shape)) + return element_size * elements diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index deca3b5a4c5a..fc850e37379c 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -665,7 +665,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Apply storage rewrite pass to the runner function to do memory planning auto storage_rewrite = tir::transform::StorageRewrite(); mod_run = storage_rewrite(mod_run); - // The workspace for main function should be calculated after performing storage_rewrite for // the top level TIR function. auto workspace_byte_alignment = diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc new file mode 100644 index 000000000000..61a880e17ffb --- /dev/null +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../../../runtime/file_utils.h" + +namespace tvm { +namespace runtime { + +class EthosUModuleNode : public ModuleNode { + public: + /*! + * \brief The ethos runtime module. + * + * \param cmms A array of external symbol 1, serialized command stream 1 + * external symbol 2, serialized command stream 2, .... + * TODO : if and when FFI support Maps with non-objects OR compound arrays + * switch to that. + */ + explicit EthosUModuleNode(const String& func_name_, const String& cmms_hex_, + const String& weights_bias_hex_, const Integer& scratch_size_, + const Integer& input_size_, const Integer& output_size_) { + func_names_.push_back(func_name_); + cmms_hex = std::move(cmms_hex_); + weights_bias_hex = std::move(weights_bias_hex_); + scratch_size = scratch_size_->value; + input_size = input_size_->value; + output_size = output_size_->value; + c_source = GenerateSource(); + } + + /*! + * \brief Save the module to file. + * + * \param file_name The file to be saved to. + * \param format The format of the file. + */ + void SaveToFile(const std::string& file_name, const std::string& format) final { + std::string fmt = GetFileFormat(file_name, format); + LOG(INFO) << "format=" << fmt << ";;\n"; + ICHECK_EQ(fmt, "c") << "Can only save to format=" + << "c"; + std::ofstream out(file_name); + out << c_source; + out.close(); + } + + std::string GetSource(const std::string& format) final { return c_source; } + + std::string GetCS() { return cmms_hex; } + + /*! + * \brief Get a PackedFunc from the module. + * + * \param name The name of the function. + * \param sptr_to_self The ObjectPtr that points to this module node. + * + * \return The function pointer when it is found, otherwise, PackedFunc(nullptr). + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); + } + return PackedFunc(); + } + + const char* type_key() const override { return "c"; } + + static Module Create(String func_name, String cmms_hex, String weights_bias_hex, + Integer scratch_size, Integer input_size, Integer output_size) { + auto n = make_object(func_name, cmms_hex, weights_bias_hex, scratch_size, + input_size, output_size); + return Module(n); + } + + private: + String c_source; + Array func_names_; + String cmms_hex; + String weights_bias_hex; + size_t scratch_size; + size_t input_size; + size_t output_size; + int indent_{0}; + + /*! + * \brief Convert the raw string of hex values into a hex string + * + * \param raw the raw string of hex values + * + * \return string formatted as a hex string + */ + std::string GetHexString(const std::string& raw) { + std::stringstream ss; + for (size_t i = 0; i < raw.size() / 2; ++i) { + ss << "\\x" << raw.substr(i * 2, 2); + } + return ss.str(); + } + + /*! + * \brief Emit code that updates the base_addrs array with the base address of the given array + * + * \param index array index for base_addrs and base_addrs_size + * \param name of the array containing relevant data + * + * \return string of code that updates the base_addrs array with the base address of the given + * array + */ + std::string SetBaseAddress(int index, std::string name) { + std::stringstream ss; + ss << " base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n"; + ss << " base_addrs_size[" << index << "] = " << name << "_size;\n"; + return ss.str(); + } + + /*! + * \brief Enter a new scope. + */ + void EnterScope() { indent_ += 2; } + + /*! + * \brief Exit a scope. + */ + void ExitScope() { + ICHECK_GE(indent_, 2U) << "Wrong ident found."; + indent_ -= 2; + } + + /*! \brief Print indents using spaces. */ + void PrintIndents(std::stringstream& ss) { + for (int i = 0; i < indent_; i++) { + ss << ' '; + } + } + + /*! + * \brief Creates a runtime function header + */ + void PrintRuntimeFunctionHeader(std::stringstream& ss, std::string func_name) { + ss << "TVM_DLL int32_t "; + ss << func_name << "(void* input, void* output) {\n"; + } + + /*! + * \brief Creates a cplusplus guard prefix for extern "C" printing + */ + void PrintExternCPrefix(std::stringstream& ss) { + PrintIndents(ss); + ss << "#ifdef __cplusplus\n"; + ss << "extern \"C\" {\n"; + ss << "#endif\n"; + } + + /*! + * \brief Creates a cplusplus guard postfix for extern "C" printing + */ + void PrintExternCPostfix(std::stringstream& ss) { + PrintIndents(ss); + ss << "#ifdef __cplusplus\n"; + ss << "}\n"; + ss << "#endif\n"; + } + + /*! + * \brief Emit code that offloads a subgraph to the NPU + * + * \return string of code that offloads a subgraph to the NPU + */ + std::string GenerateSource() { + std::string func_no_dashes = func_names_[0]; + std::replace(func_no_dashes.begin(), func_no_dashes.end(), '-', '_'); + std::stringstream ss; + + ss << "#include \n"; + ss << "#include \n"; + ss << "#include \n"; + ss << "#include \n"; + ss << "#include \n"; + ss << "\n"; + size_t weights_size = (weights_bias_hex.size() / 2); + ss << "static const size_t weights_size = " << std::to_string(weights_size) << ";\n"; + ss << "static const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; + ss << "// Update linker script to place ethosu_scratch in memory that can be accessed by the " + "NPU\n"; + if (weights_size > 0) { + ss << "__attribute__((section(\"ethosu_scratch\"), aligned(16))) static int8_t weights[" + << weights_size << "] = \""; + ss << GetHexString(weights_bias_hex); + ss << "\";\n"; + } else { + ss << "static int8_t* weights = NULL;\n"; + } + ss << "__attribute__((section(\"ethosu_scratch\"), aligned(16))) static int8_t cms_data_data[" + << cmms_hex.size() / 2 << "] = \""; + ss << GetHexString(cmms_hex); + ss << "\";\n"; + ss << "static const size_t cms_data_size = sizeof(cms_data_data);\n"; + ss << "\n"; + + PrintExternCPrefix(ss); + ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, " + << "size_t in0_size, int8_t* out0, size_t out0_size) {\n"; + ss << " int num_tensors = 5;\n"; + ss << " void* cms_data = (void*)(cms_data_data);\n"; + ss << " int64_t device_type = kDLCPU;\n"; + ss << " int64_t device_id = 0;\n"; + if (scratch_size > 0) { + ss << " int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " + "(uint64_t)scratch_size, 0, 16);\n"; + } else { + ss << " int8_t* scratch = NULL;\n"; + } + ss << " size_t base_addrs_size[num_tensors];\n"; + ss << " uint64_t base_addrs[num_tensors];\n"; + ss << "\n"; + ss << SetBaseAddress(0, "weights"); + ss << SetBaseAddress(1, "scratch"); + ss << SetBaseAddress(2, "scratch"); + ss << SetBaseAddress(3, "in0"); + ss << SetBaseAddress(4, "out0"); + ss << "\n"; + ss << " struct ethosu_driver *drv = ethosu_reserve_driver();\n"; + ss << " int32_t result = ethosu_invoke(drv, cms_data, cms_data_size, base_addrs, " + "base_addrs_size, " + "num_tensors);\n"; + ss << " ethosu_release_driver(drv);\n"; + if (scratch_size > 0) { + ss << " TVMBackendFreeWorkspace(device_type, device_id, scratch);\n"; + } + ss << " if (result != 0) {\n"; + ss << " return -1;\n"; + ss << " } else {\n"; + ss << " return 0;\n"; + ss << " }\n"; + ss << "}\n"; + ss << "\n"; + PrintExternCPostfix(ss); + ss << "\n"; + PrintExternCPrefix(ss); + ss << "// Wrapper function is provided to allow for easier debugging\n"; + ss << "inline static int32_t " + func_no_dashes + "_wrapper_(void* input, void* output) {\n"; + ss << " size_t input_data_size = " << input_size << ";\n"; + ss << " size_t output_data_size = " << output_size << ";\n"; + ss << " return " + func_no_dashes + + "_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size);\n"; + ss << "}\n"; + PrintExternCPostfix(ss); + ss << "\n"; + PrintExternCPrefix(ss); + PrintRuntimeFunctionHeader(ss, func_names_[0]); + EnterScope(); + PrintIndents(ss); + ss << "return " << func_no_dashes << "_wrapper_(input, output);\n"; + ExitScope(); + ss << "}\n"; + PrintExternCPostfix(ss); + + return ss.str(); + } +}; + +class EthosUModule : public Module { + public: + EthosUModule() {} + explicit EthosUModule(ObjectPtr n) : Module(n) {} + /*! \return internal container */ + inline EthosUModuleNode* operator->(); + /*! \return internal container */ + inline const EthosUModuleNode* operator->() const; +}; + +inline EthosUModuleNode* EthosUModule::operator->() { + return static_cast(get_mutable()); +} + +TVM_REGISTER_GLOBAL("runtime.module.ethosu.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = EthosUModuleNode::Create(args[0], args[1], args[2], args[3], args[4], args[5]); +}); + +TVM_REGISTER_GLOBAL("runtime.module.ethosu.getcs").set_body_typed([](EthosUModule mod) { + return mod->GetCS(); +}); + +} // namespace runtime +} // namespace tvm diff --git a/tests/python/contrib/test_ethosu/__init__.py b/tests/python/contrib/test_ethosu/__init__.py new file mode 100644 index 000000000000..e23e5fc926b2 --- /dev/null +++ b/tests/python/contrib/test_ethosu/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test infrastructure for Arm(R) Ethos(TM)-U NPU related tests""" diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index fc795c066cb6..aeed64ad4aec 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -24,15 +24,31 @@ the command stream and perform an equivalency check for single operator test cases. """ +from typing import List +import os +import struct import numpy from enum import IntEnum +from ethosu.vela.register_command_stream_generator import CmdMode +from ethosu.vela.register_command_stream_generator import cmd0 +from ethosu.vela.register_command_stream_generator import cmd1 + import tvm from tvm import relay import tvm.relay.backend.contrib.ethosu.op as ethosu_ops from tvm.topi.nn.utils import get_pad_tuple +from tests.python.relay.aot.aot_test_utils import ( + AOTCompiledTestModel, + AOTDataLinkage, + AOTTestModel, + AOTTestRunner, + compile_models, + run_and_check, +) + class AttachType(IntEnum): kGroupRoot = 1 @@ -42,6 +58,218 @@ class AttachType(IntEnum): kScanUpdate = 5 +class VelaArtifacts: + def __init__(self): + self.cs = dict() + self.flash = dict() + self.sram = dict() + self.npu_ops = set() + + +def parse_relay_tflite_model(tflite_model, input_tensor, input_shape, input_dtype): + mod_, params_ = relay.frontend.from_tflite( + tflite_model, + shape_dict={input_tensor: input_shape}, + dtype_dict={input_tensor: input_dtype}, + ) + return mod_, params_ + + +def parse_tflite_model(model_file): + try: + import tflite + + return tflite.Model.GetRootAsModel(model_file, 0) + except AttributeError: + import tflite.Model + + return tflite.Model.Model.GetRootAsModel(model_file, 0) + + +def print_payload(payload): + cmds = deserialize_command_stream(payload) + for cmd_val in cmds: + cmd, val = parse_cmd(cmd_val) + s = str(cmd) + s = s.ljust(40) + s += str(val) + print(s) + + +def parse_cmd(binary_cmd): + code = binary_cmd[0] & 0x0000FFFF # lower 16 bits + param = binary_cmd[0] >> 16 # higher 16 bits + payload_mode = CmdMode(code & CmdMode.Mask) + if payload_mode == CmdMode.Payload32: + command = cmd1(code & CmdMode.CmdOpMask) + value = binary_cmd[1] + else: + command = cmd0(code & CmdMode.CmdOpMask) + value = param + return command, value + + +def check_cmms_equivalency(vela_cmd, vela_value, tvm_value, ignore_cmds=None): + if ignore_cmds is None: + ignore_cmds = [] + if vela_value != tvm_value and vela_cmd not in ignore_cmds: + raise RuntimeError( + "ValueMismatch :: vela={}, tvm={} for command:{}".format( + vela_value, tvm_value, vela_cmd + ) + ) + + +def verify_cmms(cmms_tvm_blob, cmms_vela_blob): + vela_cmm = deserialize_command_stream(cmms_vela_blob) + tvm_cmm = deserialize_command_stream(cmms_tvm_blob) + cmms_zip = zip(vela_cmm, tvm_cmm) + + first_ifm_found = False + last_ofm_found = False + + ignore_commands = ( + cmd1.NPU_SET_DMA0_SRC, + cmd1.NPU_SET_DMA0_DST, + cmd1.NPU_SET_WEIGHT_BASE, + cmd1.NPU_SET_OFM_BASE0, + cmd1.NPU_SET_IFM_BASE0, + cmd1.NPU_SET_SCALE_BASE, + ) + + ofm_region_params = [] + ofm_bases = [] + for vela_cmm, tvm_cmm in cmms_zip: + vela_cmd, vela_value = parse_cmd(vela_cmm) + tvm_cmd, tvm_value = parse_cmd(tvm_cmm) + + assert vela_cmd == tvm_cmd + + # The first IFM region could be different, but it needs to be 1 and 3. + if vela_cmd == cmd0.NPU_SET_IFM_REGION and not first_ifm_found: + if vela_value == 1 and tvm_value == 3: + first_ifm_found = True + continue + + if vela_cmd == cmd1.NPU_SET_IFM_BASE0 and not first_ifm_found: + if tvm_value != 0: + raise RuntimeError("ValueError :: tvm primary ifm base should be zero") + continue + + # OFM regions should be cached to be checked later + if vela_cmd == cmd0.NPU_SET_OFM_REGION: + ofm_region_params.append((vela_value, tvm_value)) + continue + + # OFM bases should be cached to be checked later + if vela_cmd == cmd1.NPU_SET_OFM_BASE0: + ofm_bases.append((vela_value, tvm_value)) + continue + + check_cmms_equivalency(vela_cmd, vela_value, tvm_value, ignore_commands) + + # The last OFM region could be different but it should be 1 and 4. + last_vela_ofm_region, last_tvm_ofm_region = ofm_region_params.pop(-1) + if not (last_vela_ofm_region == 1 and last_tvm_ofm_region == 4): + raise RuntimeError( + "ValueMismatch :: vela={}, tvm={} for last ofm region it should be 1 and 4 respectively".format( + last_vela_ofm_region, last_tvm_ofm_region + ) + ) + + # The rest of the OFM regions should be the same. + for vela_value, tvm_value in ofm_region_params: + check_cmms_equivalency(vela_cmd, vela_value, tvm_value, ignore_commands) + + # The last OFM base should be zero for tvm + _, last_tvm_ofm_base = ofm_bases.pop(-1) + if not last_tvm_ofm_base == 0: + raise RuntimeError("ValueError :: tvm primary ofm base should be zero") + + +def deserialize_command_stream(blob): + assert isinstance(blob, bytes) + payload_bytes = struct.unpack("<{0}I".format(len(blob) // 4), blob) + cmms = [] + # remove_header + payload_bytes = payload_bytes[8:] + idx = 0 + while idx < len(payload_bytes): + cmd = [] + code = payload_bytes[idx] + idx += 1 + cmd.append(code) + payload_mode = CmdMode(code & CmdMode.Mask) + if payload_mode == CmdMode.Payload32: + value = payload_bytes[idx] + idx += 1 + cmd.append(value) + cmms.append(cmd) + return cmms + + +def _create_test_runner(accel): + file_dir = os.path.dirname(os.path.abspath(__file__)) + test_root = os.path.join(file_dir, "reference_system") + ethosu_macs = accel[accel.rfind("-") + 1 :] + return AOTTestRunner( + makefile="corstone300", + prologue=""" + uart_init(); + EthosuInit(); + """, + includes=["uart.h", "ethosu_55.h", "ethosu_mod.h", "hard_fault.h"], + parameters={"ETHOSU_TEST_ROOT": test_root, "NPU_VARIANT": ethosu_macs}, + pass_config={ + "relay.ext.ethosu.options": { + "accelerator_config": accel, + } + }, + ) + + +def build_source(module, inputs, outputs, accel="ethos-u55-256"): + test_runner = _create_test_runner(accel) + return compile_models( + models=AOTTestModel( + module=module, + inputs=inputs, + outputs=outputs, + output_tolerance=10, + extra_memory_in_bytes=16 * 1024 * 1024, + ), + interface_api="c", + use_unpacked_api=True, + workspace_byte_alignment=16, + pass_config=test_runner.pass_config, + ) + + +def verify_source( + models: List[AOTCompiledTestModel], + accel="ethos-u55-256", +): + """ + This method verifies the generated source from an NPU module by building it and running on an FVP. + """ + interface_api = "c" + test_runner = _create_test_runner(accel) + run_and_check( + models, + test_runner, + interface_api, + workspace_byte_alignment=16, + data_linkage=AOTDataLinkage(section="ethosu_scratch", alignment=16), + ) + + +def flatten_numpy_data(data): + """Flatten the numpy tensor to be single dimensional""" + total_elements = data.size + reshaped_data = numpy.reshape(data, [total_elements]) + return reshaped_data + + def generate_weights_data(shape, dtype): size = 1 for dim in shape: diff --git a/tests/python/contrib/test_ethosu/reference_system/arm-none-eabi-gcc.cmake b/tests/python/contrib/test_ethosu/reference_system/arm-none-eabi-gcc.cmake new file mode 100644 index 000000000000..6aeb0b7cc7c1 --- /dev/null +++ b/tests/python/contrib/test_ethosu/reference_system/arm-none-eabi-gcc.cmake @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if (__TOOLCHAIN_LOADED) + return() +endif() +set(__TOOLCHAIN_LOADED TRUE) + +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_C_COMPILER "arm-none-eabi-gcc") +set(CMAKE_CXX_COMPILER "arm-none-eabi-g++") +set(CMAKE_SYSTEM_PROCESSOR "cortex-m55" CACHE STRING "Select Cortex-M architecture. (cortex-m0, cortex-m3, cortex-m33, cortex-m4, cortex-m55, cortex-m7, etc)") + +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) + +SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +SET(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + +set(CMAKE_C_STANDARD 99) +set(CMAKE_CXX_STANDARD 14) + +# The system processor could for example be set to cortex-m33+nodsp+nofp. +set(__CPU_COMPILE_TARGET ${CMAKE_SYSTEM_PROCESSOR}) +string(REPLACE "+" ";" __CPU_FEATURES ${__CPU_COMPILE_TARGET}) +list(POP_FRONT __CPU_FEATURES CMAKE_SYSTEM_PROCESSOR) + +string(FIND ${__CPU_COMPILE_TARGET} "+" __OFFSET) +if(__OFFSET GREATER_EQUAL 0) + string(SUBSTRING ${__CPU_COMPILE_TARGET} ${__OFFSET} -1 CPU_FEATURES) +endif() + +# Add -mcpu to the compile options to override the -mcpu the CMake toolchain adds +add_compile_options(-mcpu=${__CPU_COMPILE_TARGET}) + +# Set floating point unit +if("${__CPU_COMPILE_TARGET}" MATCHES "\\+fp") + set(FLOAT hard) +elseif("${__CPU_COMPILE_TARGET}" MATCHES "\\+nofp") + set(FLOAT soft) +elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "cortex-m33" OR + "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "cortex-m55") + set(FLOAT hard) +else() + set(FLOAT soft) +endif() + +add_compile_options(-mfloat-abi=${FLOAT}) +add_link_options(-mfloat-abi=${FLOAT}) + +# Link target +add_link_options(-mcpu=${__CPU_COMPILE_TARGET}) +add_link_options(-Xlinker -Map=output.map) + +# +# Compile options +# +set(cxx_flags "-fno-unwind-tables;-fno-rtti;-fno-exceptions") + +add_compile_options("-Wall;-Wextra;-Wsign-compare;-Wunused;-Wswitch-default;\ +-Wdouble-promotion;-Wredundant-decls;-Wshadow;-Wnull-dereference;\ +-Wno-format-extra-args;-Wno-unused-function;-Wno-unused-label;\ +-Wno-missing-field-initializers;-Wno-return-type;-Wno-format;-Wno-int-conversion" + "$<$:${cxx_flags}>" + ) diff --git a/tests/python/contrib/test_ethosu/reference_system/ethosu_55.h b/tests/python/contrib/test_ethosu/reference_system/ethosu_55.h new file mode 100644 index 000000000000..41ce284956e2 --- /dev/null +++ b/tests/python/contrib/test_ethosu/reference_system/ethosu_55.h @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_55_H_ +#define TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_55_H_ + +/* Define Arm(R) Ethos(TM)-U55 specific IRQs & base address */ +#define ETHOSU_NPU_FAIL (1 << 4) +#define ETHOSU_IRQ ((IRQn_Type)56) +#define ETHOSU_BASE_ADDRESS ((void*)0x48102000) + +#endif // TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_55_H_ diff --git a/tests/python/contrib/test_ethosu/reference_system/ethosu_mod.h b/tests/python/contrib/test_ethosu/reference_system/ethosu_mod.h new file mode 100644 index 000000000000..aa5c1026bd6d --- /dev/null +++ b/tests/python/contrib/test_ethosu/reference_system/ethosu_mod.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_MOD_H_ +#define TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_MOD_H_ + +#include +// TODO(@grant-arm): Remove device specific information once RTOS support is available +#include +#include + +#include "ethosu_55.h" + +struct ethosu_driver* ethosu0_driver = ðosu_drv; + +void ethosuIrqHandler0() { ethosu_irq_handler(ethosu0_driver); } + +// Initialize Arm(R) Ethos(TM)-U NPU driver +int EthosuInit() { + if (ethosu_init(ethosu0_driver, (void*)ETHOSU_BASE_ADDRESS, NULL, 0, 1, 1)) { + printf("Failed to initialize NPU.\n"); + return -1; + } + + // Display Arm(R) Ethos(TM)-U version information useful for debugging issues + struct ethosu_version version; + ethosu_get_version(ethosu0_driver, &version); + printf( + "version={major=%u, minor=%u, status=%u}, product={major=%u}, arch={major=%u, minor=%u, " + "patch=%u}\n", + version.id.version_major, version.id.version_minor, version.id.version_status, + version.id.product_major, version.id.arch_major_rev, version.id.arch_minor_rev, + version.id.arch_patch_rev); + printf("macs_per_cc=%u, cmd_stream_version=%u, shram_size=%u\n", version.cfg.macs_per_cc, + version.cfg.cmd_stream_version, version.cfg.shram_size); + + // Assumes SCB->VTOR points to RW memory + NVIC_SetVector(ETHOSU_IRQ, (uint32_t)ðosuIrqHandler0); + NVIC_EnableIRQ(ETHOSU_IRQ); + + return 0; +} + +#endif // TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_MOD_H_ diff --git a/tests/python/contrib/test_ethosu/reference_system/hard_fault.h b/tests/python/contrib/test_ethosu/reference_system/hard_fault.h new file mode 100644 index 000000000000..9d349004848b --- /dev/null +++ b/tests/python/contrib/test_ethosu/reference_system/hard_fault.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_CONTRIB_ETHOS_U_HARD_FAULT_H_ +#define TVM_RUNTIME_CONTRIB_ETHOS_U_HARD_FAULT_H_ + +struct ExcContext { + uint32_t r0; + uint32_t r1; + uint32_t r2; + uint32_t r3; + uint32_t r12; + uint32_t lr; + uint32_t pc; + uint32_t xPsr; +}; +void HardFault_Handler() { + int irq; + struct ExcContext* e; + uint32_t sp; + asm volatile( + "mrs %0, ipsr \n" // Read IPSR (Exception number) + "sub %0, #16 \n" // Get it into IRQn_Type range + "tst lr, #4 \n" // Select the stack which was in use + "ite eq \n" + "mrseq %1, msp \n" + "mrsne %1, psp \n" + "mov %2, sp \n" + : "=r"(irq), "=r"(e), "=r"(sp)); + printf("Hard fault. irq=%d, pc=0x%08lu, lr=0x%08lu, xpsr=0x%08lu, sp=0x%08lu\n", irq, e->pc, + e->lr, e->xPsr, sp); + printf("%11s cfsr=0x%08lu bfar=0x%08lu\n", "", SCB->CFSR, SCB->BFAR); + printf("EXITTHESIM\n"); + while (1 == 1) + ; +} + +#endif // TVM_RUNTIME_CONTRIB_ETHOS_U_HARD_FAULT_H_ diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py new file mode 100644 index 000000000000..a0b21a75ef6f --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") +import os +import numpy as np +import pathlib + +import tvm +import tvm.micro as micro +from tvm import relay +from tvm.relay.backend.contrib import ethosu +from tvm.relay.backend.contrib.ethosu import util +from tests.python.relay.aot.aot_test_utils import generate_ref_data + +from . import relay_ir_builder +from . import infra + +ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"] + + +def infer_type_function_pass(func): + mod = tvm.IRModule() + mod["test"] = func + mod = relay.transform.InferType()(mod) + return mod["test"] + + +def get_shape_expr(in_expr, out_expr): + main_f = relay.Function([in_expr], out_expr) + main_f = infer_type_function_pass(main_f) + shape = [int(i) for i in main_f.body.checked_type.shape] + return shape + + +@pytest.mark.parametrize( + "accel_type", + ACCEL_TYPES, +) +def test_ethosu_conv2d(accel_type): + def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtype): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32) + c1_params.kernel.sc = relay.const(np.random.rand(32) * 2, "float32") + c1_params.strides = (1, 1) + c1_params.pad = "VALID" + c1_params.update_output_qnn_params( + input_tensor_dtype, input_tensor_dtype, input_tensor_dtype + ) + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) + + f = relay.Function([input0], c1) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c1_params] + + def create_graph_double(input_tensor_name, input_tensor_shape, input_tensor_dtype): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8) + c1_params.strides = (2, 2) + c1_params.pad = "VALID" + c1_params.update_output_qnn_params( + input_tensor_dtype, input_tensor_dtype, input_tensor_dtype + ) + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) + + c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c2_params.ifm.shape = c1_params.ofm.shape + c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16) + c2_params.strides = (1, 1) + c2_params.pad = "SAME" + c2_params.update_output_qnn_params() + c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1) + c2_params.ofm.shape = get_shape_expr(input0, c2) + + f = relay.Function([input0], c2) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c2_params, c1_params] + + def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_dtype): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8) + c1_params.strides = (2, 2) + c1_params.pad = "VALID" + c1_params.activation = "CLIP" + c1_params.clip_min = 90 + c1_params.clip_max = 110 + c1_params.update_output_qnn_params( + input_tensor_dtype, input_tensor_dtype, input_tensor_dtype + ) + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) + + c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c2_params.ifm.shape = c1_params.ofm.shape + c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16) + c2_params.strides = (1, 1) + c2_params.pad = "SAME" + c2_params.update_output_qnn_params() + c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1) + c2_params.ofm.shape = get_shape_expr(input0, c2) + + f = relay.Function([input0], c2) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c2_params, c1_params] + + test_cases = [ + (create_graph_single, ["input", (1, 300, 300, 3), "int8"]), + (create_graph_double, ["input", (1, 128, 256, 4), "int8"]), + (create_graph_activation, ["input", (1, 64, 100, 4), "int8"]), + ] + np.random.seed(42) + for test_case in test_cases: + relay_module, conv_params = test_case[0](*test_case[1]) + input_tensor, input_shape, input_dtype = test_case[1] + mod = ethosu.partition_for_ethosu(relay_module) + + # Generate reference data + in_min, in_max = util.get_range_for_dtype_str(input_dtype) + input_data = { + input_tensor: np.random.randint( + in_min, high=in_max, size=input_shape, dtype=input_dtype + ) + } + output_data = generate_ref_data(relay_module, input_data) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 0e546ae2fd24..eb3a4d8cb4da 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -26,7 +26,7 @@ from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute -from infra import make_ethosu_conv2d +from .infra import make_ethosu_conv2d # fmt: off diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 52f6995c3aaa..aad80ece97a8 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -27,7 +27,8 @@ from tvm.relay.backend.contrib.ethosu import legalize, preprocess from tvm.relay.dataflow_pattern import * from tvm.relay.op.contrib.ethosu import * -import relay_ir_builder + +from . import relay_ir_builder def test_split_indices_legalize(): diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py new file mode 100644 index 000000000000..70ce9c551f2a --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") +from tests.python.relay.aot.aot_test_utils import ( + convert_to_relay, + generate_ref_data, +) +import numpy as np + +import tvm +import tvm.micro as micro +from tvm import relay +from tvm.relay.backend.contrib import ethosu +from tvm.relay.backend.contrib.ethosu import util +import tvm.relay.testing.tf as tf_testing + +from . import infra + +ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"] + + +def test_forward_mobilenet_v1(accel_type="ethos-u55-256"): + """Test the Mobilenet V1 TF Lite model.""" + np.random.seed(23) + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/" + "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", + "mobilenet_v1_1.0_224_quant.tflite", + ) + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + input_tensor = "input" + input_dtype = "uint8" + input_shape = (1, 224, 224, 3) + in_min, in_max = util.get_range_for_dtype_str(input_dtype) + input_data = np.random.randint(in_min, high=in_max, size=input_shape, dtype=input_dtype) + + relay_mod, params = convert_to_relay(tflite_model_buf, input_data, "input") + input_data = {input_tensor: input_data} + output_data = generate_ref_data(relay_mod, input_data) + + mod = ethosu.partition_for_ethosu(relay_mod, params) + compiled_models = infra.build_source(mod, input_data, output_data, accel_type) + infra.verify_source(compiled_models, accel_type) + + +if __name__ == "__main__": + test_forward_mobilenet_v1() diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 96fe56d1778e..f66b21b92a03 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -24,7 +24,7 @@ from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader -from infra import make_ethosu_conv2d, get_convolutional_args +from .infra import make_ethosu_conv2d, get_convolutional_args @pytest.mark.parametrize( diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 222dccacc906..2d76cd654690 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -25,7 +25,7 @@ from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants -from infra import make_ethosu_conv2d +from .infra import make_ethosu_conv2d # fmt: off diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index b07f8ea7f48b..8077271ed496 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -29,7 +29,7 @@ schedule_cache_reads, ) from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te, extract_constants -from infra import AttachType, make_ethosu_conv2d +from .infra import AttachType, make_ethosu_conv2d class TestTEGraph: diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 0935c0c16e99..746f595a4422 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -26,7 +26,7 @@ import shutil import subprocess import tarfile -from typing import NamedTuple, Union, Optional, List, Dict +from typing import Any, NamedTuple, Union, Optional, List, Dict import pytest import numpy as np @@ -56,17 +56,53 @@ class AOTTestModel(NamedTuple): Dict of input names to value arrays outputs: List[np.array] Ordered list of output value arrays + output_tolerance: Optional[Union[int, float]] + Allowed tolerance of the output name: str Name to use for this model params: Optional[Dict[str, np.array]] Dict of parameter names to value arrays + extra_memory_in_bytes: int + Extra memory to allocate after planned memory """ module: tvm.IRModule inputs: Dict[str, np.array] outputs: List[np.array] + output_tolerance: Optional[Union[int, float]] = None name: str = "default" params: Optional[Dict[str, np.array]] = None + extra_memory_in_bytes: int = 0 + + +class AOTCompiledTestModel(NamedTuple): + """A compiled AOTTestModel with associated module + + Parameters + ---------- + model: AOTTestModel + Input model to be compiled + module: tvm.runtime.Module + The compiled Module for the associated AOTTestModel + """ + + model: AOTTestModel + executor_factory: tvm.relay.backend.executor_factory.AOTExecutorFactoryModule + + +class AOTDataLinkage(NamedTuple): + """A compiled AOTTestModel with associated module + + Parameters + ---------- + section: str + Named section to place data into + alignment: int + Section alignment + """ + + section: str + alignment: int class AOTTestRunner(NamedTuple): @@ -80,14 +116,17 @@ class AOTTestRunner(NamedTuple): Code to prepend to the main function includes: List[str] Additional includes required to run the AOT test runner - parameters: Map[str, str] + parameters: Dict[str, str] Additional parameters to pass to the make command + pass_config: Dict[str, Any] + Additional pass configuration when building the model """ makefile: str = "default" prologue: str = "" includes: List[str] = [] parameters: Dict[str, str] = {} + pass_config: Dict[str, Any] = {} AOT_DEFAULT_RUNNER = AOTTestRunner() @@ -225,11 +264,20 @@ def subprocess_log_output(cmd, cwd, logfile): return proc.wait() -def emit_main_prologue(main_file, custom_prologue, workspace_bytes): +# TODO: Move to linker script with list of symbols rather than coding into source +def emit_data_linkage(output_file, data_linkage): + if data_linkage is not None: + output_file.write( + f'__attribute__((section("{data_linkage.section}"), aligned({data_linkage.alignment}))) ' + ) + + +def emit_main_prologue(main_file, custom_prologue, workspace_bytes, data_linkage): # Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment. main_file.write( f"#define WORKSPACE_SIZE ({workspace_bytes} + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n" ) + emit_data_linkage(main_file, data_linkage) main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n") main_file.write("tvm_workspace_t app_workspace;\n") main_file.write( @@ -242,9 +290,14 @@ def emit_main_prologue(main_file, custom_prologue, workspace_bytes): return StackMemoryManager_Free(&app_workspace,ptr); } -void TVMPlatformAbort(tvm_crt_error_t code) { } +void TVMPlatformAbort(tvm_crt_error_t code) { exit(-1); } -void TVMLogf(const char* msg, ...) { } +void TVMLogf(const char* msg, ...) { + va_list args; + va_start(args, msg); + vfprintf(stdout, msg, args); + va_end(args); +} TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {} int main(){\n @@ -360,23 +413,30 @@ def fake_tensor(source, source_index, packed_index): main_file.write("\n") -def emit_main_compare(main_file, output_list, mod_name): +def emit_main_compare(main_file, output_list, output_tolerance, mod_name): num_outputs = len(output_list) actual_data_name = mangle_name(mod_name, "output_data") expected_data_name = mangle_name(mod_name, "expected_output_data") for i in range(0, num_outputs): is_float_dtype = output_list[i].dtype == "float32" - main_file.write(f"for (int i = 0; i<{actual_data_name}{i}_len; i++){{\n") + + comparison_function = "abs" + tolerance = output_tolerance or 0 if is_float_dtype: - main_file.write( - f'if (fabs({actual_data_name}{i}[i]-{expected_data_name}{i}[i]) > 0.001f){{\n\tprintf("{AOT_FAILURE_TOKEN}\\n");\n\treturn -1;}}\n' - ) - else: - main_file.write( - f'if ({actual_data_name}{i}[i]!={expected_data_name}{i}[i]){{\n\tprintf("{AOT_FAILURE_TOKEN}\\n");\n\treturn -1;}}\n' - ) - main_file.write("}\n") + comparison_function = "fabs" + tolerance = output_tolerance or 0.001 + + main_file.write( + f""" + for (int i = 0; i<{actual_data_name}{i}_len; i++) {{ + if ({comparison_function}({actual_data_name}{i}[i]-{expected_data_name}{i}[i]) > {tolerance}) {{ + printf("{AOT_FAILURE_TOKEN}\\n"); + return -1; + }} + }} + """ + ) def emit_main_init_memory_manager(main_file): @@ -392,6 +452,8 @@ def emit_main_epilogue(main_file): def emit_main_common_includes(main_file, custom_includes): main_file.write("#include \n") + main_file.write("#include \n") + main_file.write("#include \n") main_file.write("#include \n") main_file.write('#include "tvm/runtime/c_runtime_api.h"\n') main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n') @@ -404,7 +466,14 @@ def emit_main_micro_include(main_file, mod_name): def create_main( - test_name, models, output_path, custom_includes, custom_prologue, interface_api, workspace_bytes + test_name, + models, + output_path, + custom_includes, + custom_prologue, + data_linkage, + interface_api, + workspace_bytes, ): file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() # create header file @@ -418,7 +487,7 @@ def create_main( for model in models: emit_main_data(main_file, model.inputs, model.outputs, model.name) - emit_main_prologue(main_file, custom_prologue, workspace_bytes) + emit_main_prologue(main_file, custom_prologue, workspace_bytes, data_linkage) emit_main_init_memory_manager(main_file) if interface_api == "c": @@ -432,11 +501,11 @@ def create_main( emit_main_packed_call(main_file, model.inputs, model.outputs, model.name) for model in models: - emit_main_compare(main_file, model.outputs, model.name) + emit_main_compare(main_file, model.outputs, model.output_tolerance, model.name) emit_main_epilogue(main_file) -def create_header_file(tensor_name, npy_data, output_path): +def create_header_file(tensor_name, npy_data, output_path, data_linkage): """ This method generates a header file containing the data contained in the numpy array provided. It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into the standalone application. @@ -450,6 +519,8 @@ def create_header_file(tensor_name, npy_data, output_path): header_file.write("#include \n") header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n") + emit_data_linkage(header_file, data_linkage) + if npy_data.dtype == "int8": header_file.write(f"int8_t {tensor_name}[] =") elif npy_data.dtype == "int32": @@ -473,57 +544,57 @@ def extract_main_workspace_size_bytes(extract_dir): def compile_models( models: Union[List[AOTTestModel], AOTTestModel], - interface_api, - use_unpacked_api, - workspace_byte_alignment=8, - enable_op_fusion=True, -): + interface_api: str, + use_unpacked_api: bool, + workspace_byte_alignment: int = 8, + enable_op_fusion: bool = True, + pass_config: Dict[str, Any] = None, +) -> List[AOTCompiledTestModel]: """ This method generates runtime.Modules for the tests """ + if not isinstance(models, list): + models = [models] base_target = "c -runtime=c --link-params --executor=aot" extra_target = f"--workspace-byte-alignment={workspace_byte_alignment} --interface-api={interface_api} --unpacked-api={int(use_unpacked_api)}" target = f"{base_target} {extra_target}" - if not isinstance(models, list): - models = [models] - config = {"tir.disable_vectorize": True} + if pass_config: + config = {**config, **pass_config} if not enable_op_fusion: config["relay.FuseOps.max_depth"] = 1 - compiled_runtime_mods = list() + compiled_mods = list() for model in models: with tvm.transform.PassContext(opt_level=3, config=config): - compiled_runtime_mods.append( - tvm.relay.build( - model.module, - target, - target_host=target, - params=model.params, - mod_name=model.name, - ) + executor_factory = tvm.relay.build( + model.module, + target, + target_host=target, + params=model.params, + mod_name=model.name, + ) + compiled_mods.append( + AOTCompiledTestModel(model=model, executor_factory=executor_factory) ) - return compiled_runtime_mods + return compiled_mods def run_and_check( - models: Union[List[AOTTestModel], AOTTestModel], + models: List[AOTCompiledTestModel], runner: AOTTestRunner, - interface_api, - compiled_runtime_mods: List[tvm.runtime.Module], + interface_api: str, debug_calculated_workspaces=False, workspace_byte_alignment=8, + data_linkage: AOTDataLinkage = None, ): """ This method uses the original test data and compiled runtime.Modules to run in the test runner to verify the results. """ - if not isinstance(models, list): - models = [models] - tmp_path = utils.tempdir() tmp_dir = tmp_path.temp_dir @@ -545,12 +616,14 @@ def run_and_check( ) workspace_bytes = 0 - for runtime_module, model in zip(compiled_runtime_mods, models): + for compiled_model in models: + model = compiled_model.model tar_file = os.path.join(base_path, f"{model.name}.tar") - export_model_library_format(runtime_module, tar_file) + export_model_library_format(compiled_model.executor_factory, tar_file) t = tarfile.open(tar_file) t.extractall(base_path) + workspace_bytes += model.extra_memory_in_bytes workspace_bytes += extract_main_workspace_size_bytes(base_path) for key in model.inputs: @@ -559,6 +632,7 @@ def run_and_check( f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}', model.inputs[key], include_path, + data_linkage, ) for i in range(len(model.outputs)): @@ -566,19 +640,22 @@ def run_and_check( (f'{mangle_name(model.name,"output_data")}{i}'), np.zeros(model.outputs[i].shape, model.outputs[i].dtype), include_path, + data_linkage, ) create_header_file( (f'{mangle_name(model.name, "expected_output_data")}{i}'), model.outputs[i], include_path, + data_linkage, ) create_main( "test.c", - models, + [compiled_model.model for compiled_model in models], build_path, runner.includes, runner.prologue, + data_linkage, interface_api, workspace_bytes, ) @@ -616,23 +693,29 @@ def run_and_check( def compile_and_run( models: Union[List[AOTTestModel], AOTTestModel], runner: AOTTestRunner, - interface_api, - use_unpacked_api, - debug_calculated_workspaces=False, - workspace_byte_alignment=8, - enable_op_fusion=True, + interface_api: str, + use_unpacked_api: bool, + debug_calculated_workspaces: bool = False, + workspace_byte_alignment: int = 8, + enable_op_fusion: bool = True, + data_linkage: AOTDataLinkage = None, ): """This is a wrapper API to compile and run models as test for AoT""" - compiled_runtime_mods = compile_models( - models, interface_api, use_unpacked_api, workspace_byte_alignment, enable_op_fusion + compiled_test_mods = compile_models( + models=models, + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + workspace_byte_alignment=workspace_byte_alignment, + enable_op_fusion=enable_op_fusion, + pass_config=runner.pass_config, ) run_and_check( - models, - runner, - interface_api, - compiled_runtime_mods, - debug_calculated_workspaces, - workspace_byte_alignment, + models=compiled_test_mods, + runner=runner, + interface_api=interface_api, + debug_calculated_workspaces=debug_calculated_workspaces, + workspace_byte_alignment=workspace_byte_alignment, + data_linkage=data_linkage, ) diff --git a/tests/python/relay/aot/corstone300.ld b/tests/python/relay/aot/corstone300.ld index 4a6b22480d9f..9534b869f6e6 100644 --- a/tests/python/relay/aot/corstone300.ld +++ b/tests/python/relay/aot/corstone300.ld @@ -257,6 +257,14 @@ SECTIONS __bss_end__ = .; } > DTCM AT > DTCM + .ddr : + { + . = ALIGN(4); + . = ALIGN(16); + *(ethosu_scratch) + . = ALIGN (16); + } > DDR + .data_sram : { . = ALIGN(16); diff --git a/tests/python/relay/aot/corstone300.mk b/tests/python/relay/aot/corstone300.mk index 3a946f2cd876..8d03ccc5b5f4 100644 --- a/tests/python/relay/aot/corstone300.mk +++ b/tests/python/relay/aot/corstone300.mk @@ -28,9 +28,11 @@ endif ARM_CPU=ARMCM55 DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core ETHOSU_PATH=/opt/arm/ethosu +DRIVER_PATH=${ETHOSU_PATH}/core_driver CMSIS_PATH=${ETHOSU_PATH}/cmsis PLATFORM_PATH=${ETHOSU_PATH}/core_platform/targets/corstone-300 PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -mcpu=cortex-m55 -mthumb -mfloat-abi=hard -std=gnu99 +CMAKE = /opt/arm/cmake/bin/cmake CC = arm-none-eabi-gcc AR = arm-none-eabi-ar RANLIB = arm-none-eabi-ranlib @@ -40,11 +42,15 @@ PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ -I$(build_dir)/../include \ -I$(CODEGEN_ROOT)/host/include \ -I${PLATFORM_PATH} \ + -I${DRIVER_PATH}/include \ -I${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Include/ \ -I${CMSIS_PATH}/CMSIS/Core/Include \ -I${CMSIS_PATH}/CMSIS/NN/Include \ -I${CMSIS_PATH}/CMSIS/DSP/Include \ - -isystem$(STANDALONE_CRT_DIR)/include \ + -isystem$(STANDALONE_CRT_DIR)/include +DRIVER_CMAKE_FLAGS = -DCMAKE_TOOLCHAIN_FILE=$(ETHOSU_TEST_ROOT)/arm-none-eabi-gcc.cmake \ + -DETHOSU_LOG_SEVERITY=debug \ + -DCMAKE_SYSTEM_PROCESSOR=cortex-m55 PKG_LDFLAGS = -lm -specs=nosys.specs -static -T ${AOT_TEST_ROOT}/corstone300.ld @@ -61,6 +67,11 @@ CMSIS_STARTUP_SRCS = $(shell find ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c CMSIS_NN_SRCS = $(shell find ${CMSIS_PATH}/CMSIS/NN/Source/*/*.c) UART_SRCS = $(shell find ${PLATFORM_PATH}/*.c) +ifdef ETHOSU_TEST_ROOT +ETHOSU_ARCHIVE=${build_dir}/ethosu_core_driver/libethosu_core_driver.a +ETHOSU_INCLUDE=-I$(ETHOSU_TEST_ROOT) +endif + aot_test_runner: $(build_dir)/aot_test_runner $(build_dir)/stack_allocator.o: $(TVM_ROOT)/src/runtime/crt/memory/stack_allocator.c @@ -94,9 +105,14 @@ ${build_dir}/libuart.a: $(UART_SRCS) $(QUIET)$(AR) -cr $(abspath $(build_dir)/libuart.a) $(abspath $(build_dir))/libuart/*.o $(QUIET)$(RANLIB) $(abspath $(build_dir)/libuart.a) -$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libcmsis_nn.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a +${build_dir}/ethosu_core_driver/libethosu_core_driver.a: + $(QUIET)mkdir -p $(@D) + $(QUIET)cd $(DRIVER_PATH) && $(CMAKE) -B $(abspath $(build_dir)/ethosu_core_driver) $(DRIVER_CMAKE_FLAGS) + $(QUIET)cd $(abspath $(build_dir)/ethosu_core_driver) && $(MAKE) + +$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libcmsis_nn.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a $(ETHOSU_ARCHIVE) $(QUIET)mkdir -p $(@D) - $(QUIET)$(CC) $(PKG_CFLAGS) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS) + $(QUIET)$(CC) $(PKG_CFLAGS) $(ETHOSU_INCLUDE) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS) clean: $(QUIET)rm -rf $(build_dir)/crt @@ -109,6 +125,7 @@ run: $(build_dir)/aot_test_runner -C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \ -C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \ -C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \ + -C ethosu.extra_args="--fast" \ -C ethosu.num_macs=$(NPU_VARIANT) $(build_dir)/aot_test_runner .SUFFIXES: diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 9961cd567fbe..d90c4217c4c3 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -299,13 +299,22 @@ def test_mobilenet(debug_calculated_workspaces, workspace_byte_alignment): interface_api = "c" test_runner = AOT_DEFAULT_RUNNER + # TODO(@Mousius) - Enable memory planning to take into account debug information + debugging_memory_overhead = 1024 * 1024 + mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype("float32") inputs = {"data": data} output_list = generate_ref_data(mod, inputs, params) compile_and_run( - AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + AOTTestModel( + module=mod, + inputs=inputs, + outputs=output_list, + params=params, + extra_memory_in_bytes=debugging_memory_overhead, + ), test_runner, interface_api, use_unpacked_api, @@ -673,12 +682,12 @@ def @main(%data: Tensor[(1, 4, 4, 4), float32], %weight: Tensor[(4, 4, 3, 3), fl } """ ) - compiled_runtime_modules = compile_models( - AOTTestModel(module=relay_mod, inputs=None, outputs=None), - "c", - True, + compiled_test_mods = compile_models( + models=AOTTestModel(module=relay_mod, inputs=None, outputs=None), + interface_api="c", + use_unpacked_api=True, ) - source = compiled_runtime_modules[0].lib.imported_modules[0].get_source() + source = compiled_test_mods[0].executor_factory.lib.imported_modules[0].get_source() # There should be three allocates created for three primitive relay function # calls in the main for the above relay snippet. assert source.count("TVMBackendAllocWorkspace") == 3 From df50fa3dcf0ba7993944389c2b6a5724b0f77730 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 29 Sep 2021 11:59:33 -0500 Subject: [PATCH 18/20] [LLVM] Make changes needed for opaque pointers (#9138) * [LLVM] Make changes needed for opaque pointers - Pass value type to all Create.*Load and Create.*GEP functions. - Create type TypedPointer to keep both the address and the pointee's type when buffer pointers etc. are created. - Eliminate calls to getPointerElementType, except one in creating debug info (that seems necessary for the time being). * Fix typo in CodeGenCPU::CreateStructRefPtr * Fix type extraction in CodeGenLLVM::AddAliasInfo * Fix types in ramp-1 vector loads/stores * Fix getting intrinsic name in error message * Return valid pointer from PackClosureData when no data to pack --- src/target/llvm/codegen_cpu.cc | 173 +++++++++++++++++++---------- src/target/llvm/codegen_cpu.h | 6 +- src/target/llvm/codegen_hexagon.cc | 92 +++++++++------ src/target/llvm/codegen_llvm.cc | 122 ++++++++++++-------- src/target/llvm/codegen_llvm.h | 11 +- 5 files changed, 259 insertions(+), 145 deletions(-) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index c98c23ae8c61..466f85393b1b 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -246,8 +246,9 @@ std::unique_ptr CodeGenCPU::Finish() { } return CodeGenLLVM::Finish(); } -llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, - int kind) { + +CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, + llvm::Value* index, int kind) { if (kind < builtin::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); @@ -257,57 +258,87 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } switch (kind) { case builtin::kArrAddr: { - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_tvm_array_, builder_->CreateInBoundsGEP(t_tvm_array_, buf, index)); } case builtin::kArrData: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrShape: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(4); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(4)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrStrides: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(5); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(5)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrNDim: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(2); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(2)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeCode: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeBits: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(1); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(1)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeLanes: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(2); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(2)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrByteOffset: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(6); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(6)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrDeviceId: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(1); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(1)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrDeviceType: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kTVMValueContent: { ICHECK_EQ(t.lanes(), 1); ICHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); } else if (t.is_float()) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); } else { ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(buf, index); - return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); + return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); } } default: LOG(FATAL) << "unknown field code"; - return nullptr; + return TypedPointer(); } } @@ -373,7 +404,10 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { ICHECK(gv != nullptr); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment())); + llvm::LoadInst* faddr = + builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment()); #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif @@ -490,10 +524,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { builder_->SetInsertPoint(compute_call_end); } -llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* num_bytes) { +CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array& vfields, + uint64_t* num_bytes) { if (vfields.size() == 0) { *num_bytes = 0U; - return llvm::Constant::getNullValue(t_void_p_); + return TypedPointer(t_void_p_, llvm::Constant::getNullValue(t_void_p_)); } std::vector fields; for (Var v : vfields) { @@ -501,23 +536,24 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* nu ICHECK(it != var_map_.end()); fields.push_back(it->second->getType()); } - llvm::StructType* tcdata = llvm::StructType::create(fields); - llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1)); + llvm::StructType* ctype = llvm::StructType::create(fields); + llvm::Value* cvalue = builder_->CreateAlloca(ctype, ConstInt32(1)); llvm::Value* zero = ConstInt32(0); for (size_t i = 0; i < vfields.size(); ++i) { builder_->CreateStore(var_map_.at(vfields[i].get()), - builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); + builder_->CreateInBoundsGEP(ctype, cvalue, {zero, ConstInt32(i)})); } - *num_bytes = data_layout_->getTypeAllocSize( - llvm::cast(cdata->getType())->getElementType()); - return cdata; + *num_bytes = data_layout_->getTypeAllocSize(ctype); + return TypedPointer(ctype, cvalue); } -void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array& vfields, +void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array& vfields, std::unordered_map* vmap) { for (size_t i = 0; i < vfields.size(); ++i) { - (*vmap)[vfields[i].get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)})); + llvm::Type* field_type = cdata.type->getStructElementType(i); + llvm::Value* field_addr = + builder_->CreateInBoundsGEP(cdata.type, cdata.addr, {ConstInt32(0), ConstInt32(i)}); + (*vmap)[vfields[i].get()] = builder_->CreateLoad(field_type, field_addr); } } @@ -530,21 +566,22 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { // allocate and setup the closure, call the closure. Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; - llvm::Value* cdata = PackClosureData(vfields, &nbytes); + TypedPointer cdata = PackClosureData(vfields, &nbytes); #if TVM_LLVM_VERSION >= 90 auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); #else auto launch_callee = RuntimeTVMParallelLaunch(); #endif BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall( - launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); + launch_callee, + {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); llvm::Value* penv = &(*it++); - cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); + cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); // setup new variable map, swap it with current var context. std::unordered_map new_vmap; UnpackClosureData(cdata, vfields, &new_vmap); @@ -553,8 +590,9 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { par_env.task_id = Var("task_id", DataType::Int(32)); par_env.num_task = Var("num_task", DataType::Int(32)); new_vmap[par_env.task_id.get()] = task_id; - new_vmap[par_env.num_task.get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)})); + new_vmap[par_env.num_task.get()] = builder_->CreateLoad( + t_int32_, + builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)})); par_env.penv = penv; auto new_analyzer = std::make_unique(); std::swap(function_, f); @@ -600,14 +638,14 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod // allocate and setup the closure, call the closure. uint64_t nbytes; Array vfields = tir::UndefinedVars(body, {}); - llvm::Value* cdata = PackClosureData(vfields, &nbytes); + TypedPointer cdata = PackClosureData(vfields, &nbytes); BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( - finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); + finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); - cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); + cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); // setup new variable map, swap it with current var context. std::unordered_map new_vmap; UnpackClosureData(cdata, vfields, &new_vmap); @@ -655,7 +693,9 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); + llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, align); #else llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); #endif @@ -667,8 +707,11 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = - builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif @@ -682,7 +725,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 - llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); + llvm::Value* loaded_handle = + builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align); #else llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); #endif @@ -709,11 +755,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::Value* stack_value = MakeValue(args[1]); llvm::Value* stack_tcode = MakeValue(args[2]); llvm::Value* arg_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(begin)); + TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(end)); + TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); @@ -721,15 +769,18 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& auto call_callee = RuntimeTVMFuncCall(); #endif llvm::Value* call = builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode}); + call_callee, + {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); llvm::BasicBlock* end_block = CheckCallSuccess(call); // Load the return value and cast it to the designated type (r_type). DataType r_api_type = tir::APIType(r_type); - llvm::Value* load_ptr = - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); + llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); + llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); #else llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif @@ -737,9 +788,11 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& // Load the return type code. #if TVM_LLVM_VERSION >= 110 - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); #else - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); #endif pc.end_block = end_block; @@ -882,24 +935,24 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; - llvm::Value* ref = - this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); + TypedPointer ref = + CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { - return builder_->CreatePointerCast(ref, t_void_p_); + return builder_->CreatePointerCast(ref.addr, t_void_p_); } else { - return builder_->CreateLoad(ref); + return builder_->CreateLoad(ref.type, ref.addr); } } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); - llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), + MakeValue(op->args[1]), kind); ICHECK(kind != builtin::kArrAddr); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref.type); } - builder_->CreateStore(value, ref); + builder_->CreateStore(value, ref.addr); return ConstInt32(0); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { ICHECK_EQ(op->args.size(), 2U); diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 30e61ea63f12..402189eb374d 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -105,9 +105,9 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); - llvm::Value* PackClosureData(const Array& fields, uint64_t* num_bytes); - llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); - void UnpackClosureData(llvm::Value* cdata, const Array& fields, + TypedPointer PackClosureData(const Array& fields, uint64_t* num_bytes); + TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); + void UnpackClosureData(TypedPointer cdata, const Array& fields, std::unordered_map* vmap); // Make packed call. struct PackedCall { diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index d8a64102f9cd..bffb620d49f9 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -75,7 +75,7 @@ class CodeGenHexagon final : public CodeGenLLVM { llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr}; private: - llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind); + TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind); // Check if the call to packed function is successful // if not directly finalize function and pass on return code. @@ -255,7 +255,10 @@ llvm::GlobalVariable* CodeGenHexagon::InitContextPtr(llvm::Type* p_type, std::st llvm::Value* CodeGenHexagon::GetContextPtr(llvm::GlobalVariable* gv) { ICHECK(gv != nullptr); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment())); + llvm::LoadInst* faddr = + builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment()); #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif @@ -313,11 +316,13 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const ArrayCreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(begin)); + TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(end)); + TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); @@ -325,15 +330,18 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const ArrayCreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode}); + call_callee, + {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); llvm::BasicBlock* end_block = CheckCallSuccess(call); // Load the return value and cast it to the designated type (r_type). DataType r_api_type = tir::APIType(r_type); - llvm::Value* load_ptr = - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); + llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); + llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); #else llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif @@ -341,9 +349,11 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array= 110 - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); #else - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); #endif pc.end_block = end_block; @@ -380,7 +390,9 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); + llvm::Value* handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, hptr, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, hptr, align); #else llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); #endif @@ -392,8 +404,11 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = - builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif @@ -407,7 +422,10 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 - llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); + llvm::Value* loaded_handle = + builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align); #else llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); #endif @@ -514,23 +532,23 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3); int kind = op->args[2].as()->value; - llvm::Value* ref = + TypedPointer ref = CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { - return builder_->CreatePointerCast(ref, t_void_p_); + return builder_->CreatePointerCast(ref.addr, t_void_p_); } - return builder_->CreateLoad(ref); + return builder_->CreateLoad(ref.type, ref.addr); } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4); int kind = op->args[2].as()->value; ICHECK(kind != builtin::kArrAddr); - llvm::Value* ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), + TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), MakeValue(op->args[1]), kind); llvm::Value* value = MakeValue(op->args[3]); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref.type); } - builder_->CreateStore(value, ref); + builder_->CreateStore(value, ref.addr); return ConstInt32(0); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { ICHECK_EQ(op->args.size(), 2); @@ -549,8 +567,8 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) { return CodeGenLLVM::CreateIntrinsic(op); } -llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, - int kind) { +CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, + llvm::Value* index, int kind) { static const std::map field_index = { {builtin::kArrData, 0}, {builtin::kArrDeviceType, 1}, {builtin::kArrDeviceId, 1}, {builtin::kArrNDim, 2}, {builtin::kArrTypeCode, 3}, {builtin::kArrTypeBits, 3}, @@ -581,12 +599,13 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll uint64_t byte_offset; kArrByteOffset } DLTensor; */ - llvm::Value* base_gep = builder_->CreateInBoundsGEP(buf, index, "base_gep"); + llvm::Value* base_gep = builder_->CreateInBoundsGEP(t_tvm_array_, buf, index, "base_gep"); if (kind == builtin::kArrAddr) { - return base_gep; + return TypedPointer(t_void_p_, base_gep); } llvm::Value* field_gep = builder_->CreateInBoundsGEP( - base_gep, {ConstInt32(0), ConstInt32(field_index.at(kind))}, "field_gep"); + t_tvm_array_, base_gep, {ConstInt32(0), ConstInt32(field_index.at(kind))}, "field_gep"); + llvm::Type* field_type = t_tvm_array_->getStructElementType(field_index.at(kind)); switch (kind) { // These fields have no sub-fields. case builtin::kArrData: @@ -594,10 +613,13 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll case builtin::kArrShape: case builtin::kArrStrides: case builtin::kArrByteOffset: - return field_gep; + return TypedPointer(field_type, field_gep); } - return builder_->CreateInBoundsGEP( - field_gep, {ConstInt32(0), ConstInt32(subfield_index.at(kind))}, "subfield_gep"); + llvm::Value* subfield_gep = builder_->CreateInBoundsGEP( + field_type, field_gep, {ConstInt32(0), ConstInt32(subfield_index.at(kind))}, + "subfield_gep"); + llvm::Type* subfield_type = field_type->getStructElementType(subfield_index.at(kind)); + return TypedPointer(subfield_type, subfield_gep); } if (kind == builtin::kTVMValueContent) { @@ -615,20 +637,20 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll ICHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); } else if (t.is_float()) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); } else { ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(buf, index); - return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_void_p_, buf, index); + return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); } } assert(!"Unknown kind"); - return nullptr; + return TypedPointer(); } namespace { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6aabdc1bd804..12fbf2c3e42c 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -473,9 +473,16 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, P meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta); // Extract the underlying type of the allocated buffer. - llvm::Type* buf_type = GetVarValue(buffer)->getType()->getScalarType(); - if (buf_type->isPointerTy()) { - buf_type = buf_type->getPointerElementType(); + DataType dtype = buffer->dtype; + if (buffer->type_annotation.defined()) { + Type element_type = Downcast(buffer->type_annotation)->element_type; + if (auto* ptype = element_type.as()) { + dtype = ptype->dtype; + } + } + llvm::Type* buf_type = DTypeToLLVMType(dtype); + if (!buf_type) { + buf_type = t_void_p_; } std::string tmp; @@ -737,14 +744,17 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { return ptr; } -llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { +CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, + llvm::Value* index) { llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); ICHECK(btype != nullptr); - llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace()); - if (btype != ptype) { - buffer = builder_->CreatePointerCast(buffer, ptype); + llvm::Type* llvm_type = DTypeToLLVMType(t); + llvm::PointerType* ttype = llvm_type->getPointerTo(btype->getAddressSpace()); + if (btype != ttype) { + buffer = builder_->CreatePointerCast(buffer, ttype); } - return builder_->CreateInBoundsGEP(buffer, index); + llvm::Value* ptr = builder_->CreateInBoundsGEP(llvm_type, buffer, index); + return TypedPointer(llvm_type, ptr); } llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { @@ -861,10 +871,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { : llvm::Type::getVoidTy(*ctx_); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " -#if TVM_LLVM_VERSION <= 130 - << llvm::Intrinsic::getName(id, {}); +#if TVM_LLVM_VERSION >= 130 + << llvm::Intrinsic::getBaseName(id).str(); #else - << llvm::Intrinsic::getName(id, return_type, {}); + << llvm::Intrinsic::getName(id, {}); #endif return builder_->CreateCall(f, arg_value); } else if (op->op.same_as(builtin::bitwise_and())) { @@ -888,18 +898,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); ICHECK(op->args.size() == 1 && l); - const RampNode* r = l->index.as(); - llvm::Value* ptr; - unsigned addrspace; - if (!r) { - ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); - addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); - } else { + TypedPointer buffer_ptr; + if (const RampNode* r = l->index.as()) { PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); - ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); - addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); + buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); + } else { + buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); } - return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); + unsigned addrspace = + llvm::dyn_cast(buffer_ptr.addr->getType())->getAddressSpace(); + return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace)); } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) { return llvm::Constant::getNullValue(t_void_p_); } else if (op->op.same_as(builtin::isnullptr())) { @@ -1154,29 +1162,40 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { if (t.lanes() == 1) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); - llvm::Value* ptr = CreateBufferPtr(t, buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* load = + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(load, op->buffer_var.get(), op->index); return load; } else { // vector load - unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + // The index argument is element-based, to create buffer pointer for t's element type. + TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + unsigned addrspace = + llvm::dyn_cast(buffer->getType())->getAddressSpace(); + buffer_ptr.type = DTypeToLLVMType(t); + buffer_ptr.addr = + builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); +#elif TVM_LLVM_VERSION >= 80 llvm::LoadInst* load = - builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(load, op->buffer_var.get(), op->index); return load; @@ -1187,11 +1206,15 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { int basic_align = t.bits() / 8; llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); auto f = [&](int i, llvm::Value* index) { - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(basic_align), is_volatile); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* load = + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, basic_align, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile); #endif ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); AddAliasInfo(load, op->buffer_var.get(), PrimExpr()); @@ -1271,30 +1294,36 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { if (t.lanes() == 1) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); - llvm::Value* ptr = CreateBufferPtr(t, buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = - builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile); + builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), op->index); return; } else { // vector store - unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + // The index argument is element-based, to create buffer pointer for t's element type. + TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + unsigned addrspace = + llvm::dyn_cast(buffer->getType())->getAddressSpace(); + buffer_ptr.type = DTypeToLLVMType(t); + buffer_ptr.addr = + builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), op->index); return; @@ -1305,13 +1334,14 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { // scalarized store. int basic_align = t.bits() / 8; auto f = [&](int i, llvm::Value* index) { - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, + llvm::Align(basic_align), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), - ptr, basic_align, is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore( + builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), PrimExpr()); }; diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index a4f007aeebed..177b53056354 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -181,6 +181,15 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; protected: + /*! + * \brief Address and type pair to assist in handling opaque pointers. + */ + struct TypedPointer { + TypedPointer() = default; + TypedPointer(llvm::Type* t, llvm::Value* a) : type(t), addr(a) {} + llvm::Type* type = nullptr; /*!< Type of the value pointed to. */ + llvm::Value* addr = nullptr; /*!< Address of the value. */ + }; /*! \brief The storage information */ struct StorageInfo { /*! \brief The alignment of allocation */ @@ -301,7 +310,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); - llvm::Value* CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); + TypedPointer CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); // Vector concatenation. llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); llvm::Value* CreateVecFlip(llvm::Value* vec); From d7a28f9cbb6d5fe0130ab3e32bb1ad72e6d879f0 Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Wed, 29 Sep 2021 10:00:31 -0700 Subject: [PATCH 19/20] [Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc (#9038) * [Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc Currently LowerTEPass (backend/te_compiler.cc) is a 'special' pass because it depends on a side-input DeviceMap. We'd like to remove that side-input, and instead recover the Device (and, ultimately, Target) for each (fused) primitive call from the AST alone. By doing so we also avoid needing to perform device planning twice: - It needs to be done before lowering so we know which primitives need to be compiled for which devices. - It then needs to be re-done after lowering and optimization as a prelude to memory planning. By baking the device plan into the AST we can simply do device planning before lowering, and run memory planning later, both as ordinary passes. While working on that issue we realized we currently have 3 'device planners': - transforms/device_annotation.cc, which supports only a small subset of Relay and uses a simple top-down algorithm to assign a device to every sub-expression. - analysis/context_analysis.cc, which makes a galant effort to support most of Relay, is based on unification rather than a top-down algorithm, but handles higher order functions by ad-hoc and fragile inlining. - transforms/annotate_target.cc, which works on Targets instead of Devices, but is otherwise like 'device planning'. We'd like to bring these together. In this PR we introduce a new transforms/device_planner.cc intended to replace transforms/device_annotation.cc and analysis/context_analysis.cc. We don't delete those two just yet since we need to switch all users off of them in the next PR. We also leave transforms/annotate_target.cc alone pending a proper RFC to bring devices and targets together sensibly, but have it firmly in our sights. transforms/device_planner.cc is based on analysis/context_analysis.cc, but is heavily reworked to: 1. Handle starting from existing "on_device" annotations as well as existing "device_copy" calls. 2. Be idempotent, with the idea we'll probably need to re-run it to 'refine' device planning to account for storge scopes. 3. Robustly handle all of Relay, particularly higher-order functions. For that we replace the inlining approach in analysis/context_analysis.cc with a higher-order unification domain. 4. Be a little more systematic with defaulting. 5. Capture the result of the analysis within the AST as new "device_copy" calls at device boundaries, and new/replaced "on_device" calls wherever the device for a sub-expression is not already 'obvious' from the sub-expression's lexical scope. 6. Provide helper visitors for passes which need to ask for the device for any sub-expression they are processing and/or preserve device information on rewrites. Those passes include: - backend/aot_executor_codegen.cc (AOTOnDemandAllocator) - backend/graph_plan_memory.cc (StorageAllocaBaseVisitor etc) - backend/te_compiler.cc (LowerTensorExprMutator) - backend/vm/lambda_lift.cc (LambdaLifter) - transforms/memory_alloc.cc (DialectRewriter) - transforms/to_a_normal_form.cc (Fill) - backend/vm/compiler.cc (VMFunctionCompiler) However we won't change any of those in this PR. See the draft https://github.com/apache/tvm/pull/8788 for the end game. * [checkpoint] Use Relay script for all unit tests. * [checkpoint] Hoist out DeviceDomain and DeviceDomains. * [checkpoint] Hoist out visitors * [checkpoint] Woops, left debug-only code in --- include/tvm/relay/transform.h | 11 + python/tvm/relay/transform/transform.py | 10 + src/relay/op/annotation/annotation.h | 3 - src/relay/transforms/device_aware_visitors.cc | 285 ++++ src/relay/transforms/device_aware_visitors.h | 317 ++++ src/relay/transforms/device_domains.cc | 482 ++++++ src/relay/transforms/device_domains.h | 304 ++++ src/relay/transforms/device_planner.cc | 1123 ++++++++++++++ .../relay/transforms/device_domains_test.cc | 71 + tests/python/relay/test_pass_plan_devices.py | 1320 +++++++++++++++++ 10 files changed, 3923 insertions(+), 3 deletions(-) create mode 100644 src/relay/transforms/device_aware_visitors.cc create mode 100644 src/relay/transforms/device_aware_visitors.h create mode 100644 src/relay/transforms/device_domains.cc create mode 100644 src/relay/transforms/device_domains.h create mode 100644 src/relay/transforms/device_planner.cc create mode 100644 tests/cpp/relay/relay/transforms/device_domains_test.cc create mode 100644 tests/python/relay/test_pass_plan_devices.py diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index cdd4c9c1dbd2..e740776d6d4f 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -444,6 +444,17 @@ TVM_DLL Pass RelayToTIRTargetHook(); */ TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); +/*! + * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + * every Relay sub-expression should run (and the result stored). Captures the result of that + * analysis using new "on_device" and "device_copy" CallNodes. See + * tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} + * for help recovering the device for an arbitrary sub-expression in downstream transformations. + * + * \param default_device_type DLDeviceType for default device. + */ +TVM_DLL Pass PlanDevices(DLDeviceType default_device_type); + } // namespace transform /*! diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7c79464bdd30..bb91afc06195 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1167,6 +1167,16 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() +def PlanDevices(default_device): + """ + Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + every Relay sub-expression should run (and the result stored). Captures the result of that + analysis using new "on_device" and "device_copy" CallNodes. Note that the device_id of + the default_device is ignored. + """ + return _ffi_api.PlanDevices(default_device) + + def FoldExplicitPadding(): """ FoldExplicitPadding finds explict padding before an op that can support diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index 643a82116b5b..35f8b6bf50b6 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -81,9 +81,6 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node); */ OnDeviceProps GetOnDeviceProps(const Expr& expr); -/*! \brief Returns true if \p expr is an on_device CallNode. */ -inline bool IsOnDeviceCall(const Expr& expr) { return GetOnDeviceProps(expr).body.defined(); } - /*! * \brief Returns \p function annotated with "param_device_types" and "result_device_type" * attributes capturing parameter and result devices types respectively. diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc new file mode 100644 index 000000000000..204bce53207b --- /dev/null +++ b/src/relay/transforms/device_aware_visitors.cc @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/device_aware_visitors.cc + * \brief Visitors which track the device for the current Relay expression and Relay Vars. + */ + +#include "./device_aware_visitors.h" + +namespace tvm { +namespace relay { +namespace transform { + +// TODO(mbs): We'd probably have less tendious code duplication if we redefined the memoizing +// mutator on top of the generic Functor. + +DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const { + auto props = GetOnDeviceProps(expr); + if (props.body.defined() && props.is_fixed) { + // Look through any fixed "on_device" annotations. + return props.device_type; + } + if (expr->IsInstance()) { + // Lookup variable binding. + auto itr = var_device_types_.find(Downcast(expr)); + if (itr == var_device_types_.end()) { + return kInvalidDeviceType; + } else { + return itr->second; + } + } + // Otherwise use the currently in-scope device type. + if (expr_device_types_.empty()) { + return kInvalidDeviceType; + } else { + return expr_device_types_.back(); + } +} + +void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } + +void LexicalOnDeviceMixin::ExitFunctionBody() { + ICHECK_GT(function_nesting_, 0); + --function_nesting_; +} + +void LexicalOnDeviceMixin::PushDeviceType(DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + expr_device_types_.emplace_back(device_type); +} + +void LexicalOnDeviceMixin::PopDeviceType() { + if (expr_device_types_.empty()) { + return; + } + expr_device_types_.pop_back(); +} + +void LexicalOnDeviceMixin::PushBoundVar(Var var, DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + ICHECK(var_device_types_.find(var) == var_device_types_.end()); + var_device_types_.emplace(std::move(var), device_type); +} + +void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { + auto itr = var_device_types_.find(var); + if (itr == var_device_types_.end()) { + return; + } + var_device_types_.erase(itr); +} + +void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } +} + +void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec). + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + PopBoundVar((*itr)->var); + PostVisitLet_(*itr); + } + PostVisitLetBlock_(let_node); +} + +void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + ExprVisitor::VisitExpr_(function_node); +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const CallNode* call_node) { + ExprVisitor::VisitExpr_(call_node); +} + +void DeviceAwareExprVisitor::PreVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); +} + +void DeviceAwareExprVisitor::PostVisitLet_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PostVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + Expr result = DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + + return result; + } +} + +Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector> bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + std::pair pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); + expr = inner_let_node->body; + } + + expr = VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* pre_let_node = std::get<3>(*itr); + PopBoundVar(pre_let_node->var); + Let post_let = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), + /*body=*/expr, /*span=*/std::get<2>(*itr)); + expr = PostVisitLet_(pre_let_node, post_let.get()); + } + return PostVisitLetBlock_(let_node, expr.as()); +} + +Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + Expr expr = VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + return OnDevice(expr, props.device_type, props.is_fixed); + } else { + return DeviceAwareVisitExpr_(call_node); + } +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return ExprMutator::VisitExpr_(function_node); +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const CallNode* call_node) { + return ExprMutator::VisitExpr_(call_node); +} + +void DeviceAwareExprMutator::PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ +} + +std::pair DeviceAwareExprMutator::PreVisitLetBinding_(const Var& var, + const Expr& value) { + return std::make_pair(Downcast(VisitExpr(var)), VisitExpr(value)); +} + +Expr DeviceAwareExprMutator::PostVisitLet_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +Expr DeviceAwareExprMutator::PostVisitLetBlock_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h new file mode 100644 index 000000000000..8611f87efa06 --- /dev/null +++ b/src/relay/transforms/device_aware_visitors.h @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/device_aware_visitors.h + * \brief Visitors which track the device for the current Relay expression and Relay Vars. + */ + +#ifndef TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ +#define TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ + +#include +#include +#include +#include + +#include +#include +#include + +#include "../op/annotation/annotation.h" + +namespace tvm { +namespace relay { +namespace transform { + +/*! + * \brief Helper class for expression transformers which need to keep track of the device + * holding the results of expressions and bound variables. This is recovered from the + * "on_device" function attributes and fixed "on_device" CallNodes added by the PlanDevices + * pass. + * + * \sa \p DeviceAwareExpr{Visitor,Mutator}. + */ +class LexicalOnDeviceMixin { + protected: + /*! + * \brief Returns the device type on which the result of \p expr should/will be stored, assuming + * Push/Pop DeviceType/BoundVar have been correctly called. Returns \p kInvalidDeviceType if + * stack is empty and no bound vars have device types. + */ + DLDeviceType GetInScopeDeviceType(const Expr& expr) const; + + /*! \brief Indicate a function body is being entered. */ + void EnterFunctionBody(); + + /*! \brief Indicate a function body has been processed. */ + void ExitFunctionBody(); + + /*! \brief Push a device type onto the lexical device stack. Ignore if \p kInvalidDeviceType. */ + void PushDeviceType(const DLDeviceType device_type); + + /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */ + void PopDeviceType(); + + /*! \brief Remember that \p var will be stored on \p device_type. Ignore if \p kInvalidDeviceType. + * + * CAUTION: Despite the name we don't support re-entering the same function body. + */ + void PushBoundVar(Var var, DLDeviceType device_type); + + /*! \brief Remove the binding for \p var to it's device type. Ignore if var is not bound. */ + void PopBoundVar(const Var& var); + + /*! + * \brief Returns the number of function definitions wrapping the currently visited expression. + */ + int function_nesting() const { return function_nesting_; } + + private: + /*! + * \brief The number of function bodies entered. Since many transforms need to distinguish global + * functions from local functions this supports the mixin's \p is_global() helper method. + */ + int function_nesting_ = 0; + + /*! + * \brief The stack of lexically enclosing "on_device" devices types, from outermost to innermost. + * When visiting an expression other than a variable we can assume the expression result is + * to be stored on device_type_.back(). + */ + std::vector expr_device_types_; + /*! + * \brief A map from in-scope variable to their device types. We may assume the variable is only + * ever bound to a value stored on this device at runtime. + */ + std::unordered_map + var_device_types_; +}; + +template +class DeviceAwareExprFunctor; + +/*! + * \brief ExprFunctor which tracks devices. We only support 'visitor' style implementation + * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without + * any memoization. + */ +template <> +class DeviceAwareExprFunctor : public ExprFunctor, + public LexicalOnDeviceMixin { + private: + using TSuper = ExprFunctor; + + public: + void VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } + } + + void VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* visited_let_node = *itr; + PopBoundVar(visited_let_node->var); + PostVisitLet_(visited_let_node); + } + PostVisitLetBlock_(let_node); + } + + void VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } + } + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return TSuper::VisitExpr_(function_node); + } + + virtual void DeviceAwareVisitExpr_(const CallNode* call_node) { + return TSuper::VisitExpr_(call_node); + } + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ + } + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); + } + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node) { /* no-op */ + } + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node) {} +}; + +/*! \brief ExprVisitor which tracks devices. */ +class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { + public: + using ExprVisitor::VisitExpr_; + + void VisitExpr_(const FunctionNode* function_node) final; + void VisitExpr_(const LetNode* let_node) final; + void VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual void DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node); +}; + +/*! \brief ExprMutator which tracks devices. */ +class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { + public: + Expr VisitExpr_(const FunctionNode* function_node) final; + Expr VisitExpr_(const LetNode* let_node) final; + Expr VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual Expr DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual std::pair PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation just returns a reference to the post-visited node. + */ + virtual Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation returns reference to let node. + */ + virtual Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node); +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc new file mode 100644 index 000000000000..15784856edbf --- /dev/null +++ b/src/relay/transforms/device_domains.cc @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_domains.cc + * \brief Unification domain for the device planner. + */ + +#include "./device_domains.h" + +#include + +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" + +namespace tvm { +namespace relay { +namespace transform { + +namespace { + +// Ye olde boost hash mixer. +constexpr size_t mix(size_t h1, size_t h2) { + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); +} + +/*! + * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather + * than the original "device_copy" operator. + * + * See te_compiler.cc for where this rewriting occurs. + */ +DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { + auto tir_call_attrs = call_node->attrs.as(); + if (tir_call_attrs == nullptr) { + return {}; + } + if (tir_call_attrs->metadata.count("source_device") != 1 || + tir_call_attrs->metadata.count("dst_device") != 1) { + return {}; + } + ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1"; + return { + call_node->args[0], + static_cast( + Downcast(tir_call_attrs->metadata["source_device"])->value), + static_cast(Downcast(tir_call_attrs->metadata["dst_device"])->value)}; +} + +} // namespace + +// The following hash and equality helpers give each free first-order domain pointer its own +// distinct identity. + +size_t DeviceDomainHash::operator()(const DeviceDomainPtr& domain) const { + if (domain->is_free()) { + // Give each free first-order domain its own identity. + return static_cast(reinterpret_cast(domain.get())); + } else { + size_t h = domain->args_and_result_.size(); + h = mix(h, std::hash()(static_cast(domain->device_type_))); + for (const auto& sub_domain_ptr : domain->args_and_result_) { + h = mix(h, DeviceDomainHash()(sub_domain_ptr)); + } + return h; + } +} + +bool DeviceDomainEqual::operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { + if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) { + // Mismatched arities are never equal. + // (Though we'll never ask to do such a comparison explicitly, the hash map + // may do so implicitly due to hash collisions.) + return false; + } + if (lhs->is_free() && rhs->is_free()) { + // Compare first-order free domains by their address. + return lhs.get() == rhs.get(); + } + if (lhs->args_and_result_.empty()) { + // Compare first-order domains by their device type -- free vs bound will compare as false. + return lhs->device_type_ == rhs->device_type_; + } else { + // Compare higher-order domains pointwise. + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) { + return false; + } + } + return true; + } +} + +/* static */ +DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, DLDeviceType device_type) { + if (const auto* func_type_node = type.as()) { + std::vector args_and_result; + args_and_result.reserve(func_type_node->arg_types.size() + 1); + for (const auto& arg_type : func_type_node->arg_types) { + args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType)); + } + args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type)); + return std::make_shared(std::move(args_and_result)); + } else { + return std::make_shared(device_type); + } +} + +DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { + DeviceDomainPtr root = domain; + while (true) { + auto itr = domain_to_equiv_.find(root); + if (itr == domain_to_equiv_.end()) { + break; + } + ICHECK_NE(itr->second, root); + root = itr->second; + ICHECK_NOTNULL(root); + } + // Path compression. + while (domain != root) { + auto itr = domain_to_equiv_.find(domain); + ICHECK(itr != domain_to_equiv_.end()); + domain = itr->second; + ICHECK_NOTNULL(domain); + itr->second = root; + } + return root; +} + +DeviceDomainPtr DeviceDomains::Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + // TODO(mbs): Proper diagnostics. + ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size()) + << "Device domains:" << std::endl + << ToString(lhs) << std::endl + << "and" << std::endl + << ToString(rhs) << std::endl + << "do not have the same kind and can't be unified."; + if (rhs->is_free()) { + return lhs; + } else if (lhs->is_free()) { + return rhs; + } else if (lhs->args_and_result_.empty()) { + // Must have consistent device types for first order domains. + if (lhs->device_type_ != rhs->device_type_) { + // TODO(mbs): Proper diagnostics. + std::ostringstream os; + os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_; + throw Error(os.str()); + } + return lhs; + } else { + // Recurse for higher-order. + std::vector args_and_result; + args_and_result.reserve(lhs->args_and_result_.size()); + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i])); + } + return MakeDomain(std::move(args_and_result)); + } +} + +DeviceDomainPtr DeviceDomains::Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + lhs = Lookup(lhs); + rhs = Lookup(rhs); + auto joined_domain = Join(lhs, rhs); + if (!DeviceDomainEqual()(lhs, joined_domain)) { + domain_to_equiv_.emplace(lhs, joined_domain); + } + if (!DeviceDomainEqual()(rhs, joined_domain)) { + domain_to_equiv_.emplace(rhs, joined_domain); + } + return joined_domain; +} + +void DeviceDomains::UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (!lhs->is_higher_order() && rhs->is_higher_order()) { + Collapse(lhs, rhs); + } else { + Unify(lhs, rhs); + } +} + +DeviceDomainPtr DeviceDomains::DomainFor(const Expr& expr) { + ICHECK(expr.defined()); + auto itr = expr_to_domain_.find(expr.get()); + if (itr != expr_to_domain_.end()) { + return Lookup(itr->second); + } + auto domain = Free(expr->checked_type()); + expr_to_domain_.emplace(expr.get(), domain); + return domain; +} + +DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { + auto itr = call_to_callee_domain_.find(call.get()); + if (itr != call_to_callee_domain_.end()) { + return Lookup(itr->second); + } + std::vector args_and_result; + + auto on_device_props = GetOnDeviceProps(call.get()); + auto device_copy_props = GetDeviceCopyProps(call.get()); + if (!device_copy_props.body.defined()) { + device_copy_props = GetPrimitiveDeviceCopyProps(call.get()); + } + + if (on_device_props.body.defined()) { + // on_device(expr, device_type=, is_fixed=false) + // on_device : fn():?x? + // + // on_device(expr, device_type=, is_fixed=true) + // on_device: fn(): + args_and_result.emplace_back( + ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type)); + if (on_device_props.is_fixed) { + args_and_result.emplace_back(args_and_result.front()); + } else { + args_and_result.emplace_back(Free(on_device_props.body->checked_type())); + } + } else if (device_copy_props.body.defined()) { + // device_copy(expr, src_dev_type=, dst_dev_type=) + // device_copy: fn(): + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type)); + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type)); + } else if (call->op == alloc_storage_op) { + ICHECK_EQ(call->args.size(), 2U); + // alloc_storage(size, alignment, device_type=) + // alloc_storage: fn(, ): + const auto* attrs = call->attrs.as(); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back( + ForDeviceType(call->checked_type(), static_cast(attrs->device_type))); + } else if (call->op == alloc_tensor_op) { + ICHECK_EQ(call->args.size(), 3U); + // alloc_tensor(storage, offset, shape) + // alloc_tensor: fn(?x?, , ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op == shape_func_op) { + ICHECK_EQ(call->args.size(), 3U); + // shape_func(func, inputs, outputs, is_inputs=[...]) + // shape_func: fn(..., , ): + // where ... is a free domain appropriate for func's type + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + // TODO(mbs): I think this should be on the cpu only when is_input = [false], but + // what do we do when we have multiple arguments with different is_input values? + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == shape_of_op) { + ICHECK_EQ(call->args.size(), 1U); + // shape_of(tensor) + // shape_of: fn(?x?): + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == invoke_tvm_op) { + ICHECK_EQ(call->args.size(), 3U); + // invoke_tvm_op(op, inputs, outputs) + // invoke_tvm_op: fn(..., ?x?, ?x?):?x? + // where ... is a free domain appropriate for op's type + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + } else if (call->op == reshape_tensor_op) { + ICHECK_EQ(call->args.size(), 2U); + // reshape_tensor(data, shape) + // reshape_tensor: fn(?x?, ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op->IsInstance()) { + // (arg1, ..., argn) + // : fn(?x?, ..., ?x?):?x? + // (all args and result must be first-order). + auto free_domain = Free(arb_); + for (size_t i = 0; i < call->args.size(); ++i) { + args_and_result.emplace_back(free_domain); + } + args_and_result.emplace_back(free_domain); + } else if (call->op->IsInstance()) { + // (arg1, ..., argn) + // : fn(?x1?, ..., ?xn?):?xr? + // where we force all possibly higher-order ?xi? to be collapsed to the first-order ?xr?. + // TODO(mbs): This assumes we've eta-expanded constructors, thus all constructors appear + // in callee positions. + const auto* func_type_node = call->op->checked_type().as(); + ICHECK_NOTNULL(func_type_node); + ICHECK_EQ(func_type_node->arg_types.size(), call->args.size()); + auto result_domain = Free(func_type_node->ret_type); // first-order + for (const auto& arg_type : func_type_node->arg_types) { + auto param_domain = Free(arg_type); // possibly higher-order + UnifyCollapsed(result_domain, param_domain); // collapse if required + args_and_result.emplace_back(param_domain); + } + args_and_result.emplace_back(result_domain); + } else { + // Defer to normal case where op can be an arbitrary expression. + return DomainFor(call->op); + } + auto domain = MakeDomain(std::move(args_and_result)); + call_to_callee_domain_.emplace(call.get(), domain); + return domain; +} + +void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) { + auto lhs_domain = DomainFor(lhs); + auto rhs_domain = DomainFor(rhs); + try { + Unify(lhs_domain, rhs_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expressions:" << std::endl + << PrettyPrint(lhs) << std::endl + << "with device:" << std::endl + << ToString(lhs_domain) << "and:" << std::endl + << PrettyPrint(rhs) << std::endl + << "with device:" << std::endl + << ToString(rhs_domain) << std::endl + << e.what(); + } +} + +void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + Unify(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } +} + +void DeviceDomains::UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + UnifyCollapsed(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } +} + +bool DeviceDomains::AnyFree(DeviceDomainPtr domain) { + domain = Lookup(domain); + if (domain->is_free()) { + return true; + } + for (const auto& sub_domain : domain->args_and_result_) { + if (AnyFree(sub_domain)) { + return true; + } + } + return false; +} + +void DeviceDomains::Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain) { + for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { + Unify(higher_order_domain->function_param(i), first_order_domain); + } + Unify(higher_order_domain->function_result(), first_order_domain); +} + +void DeviceDomains::SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) { + ICHECK_NE(default_device_type, kInvalidDeviceType); + domain = Lookup(domain); + if (domain->is_free()) { + // Will never throw since lhs is free. + Unify(domain, std::make_shared(default_device_type)); + } else if (!domain->args_and_result_.empty()) { + for (const auto& sub_domain : domain->args_and_result_) { + SetDefault(sub_domain, default_device_type); + } + } +} + +void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain, + DLDeviceType default_device_type) { + if (!domain->is_higher_order()) { + SetDefault(domain, default_device_type); + return; + } + DLDeviceType result_device_type = ResultDeviceType(domain); + if (result_device_type == kInvalidDeviceType) { + // If the function result device is still free use the given default. + result_device_type = default_device_type; + } + // Default any remaining free parameters to the function result device. + SetDefault(domain, result_device_type); +} + +std::string DeviceDomains::ToString(DeviceDomainPtr domain) { + domain = Lookup(domain); + std::ostringstream os; + if (domain->is_free()) { + // first-order free + os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; + } else if (domain->args_and_result_.empty()) { + // first-order bound + os << "<" << domain->device_type_ << ">"; + } else { + // higher-order + os << "fn("; + for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) { + if (i > 0) { + os << ","; + } + os << ToString(domain->args_and_result_[i]); + } + os << "):" << ToString(domain->args_and_result_.back()); + } + return os.str(); +} + +std::string DeviceDomains::ToString() { + std::ostringstream os; + for (const auto& pair : expr_to_domain_) { + os << "expression:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + for (const auto& pair : call_to_callee_domain_) { + os << "call:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "callee domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + return os.str(); +} + +DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) { + domain = Lookup(domain); + while (!domain->args_and_result_.empty()) { + domain = Lookup(domain->args_and_result_.back()); + } + return domain; +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h new file mode 100644 index 000000000000..a29370a0e807 --- /dev/null +++ b/src/relay/transforms/device_domains.h @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_domains.h + * \brief Unification domain for the device planner. + */ + +#ifndef TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ +#define TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +class DeviceDomain; +using DeviceDomainPtr = std::shared_ptr; + +/*! + * \brief Represents the domain over which we collect equality constraints. + * + * \code + * D ::= ?x? -- first order, free + * | -- first order, bound + * | fn(D1, ..., Dn):Dr -- higher order + * \endcode + * + * We require a function value to be on the same device as its result. To support that we need + * a notion of the 'result domain' of a domain: + * \code + * result_domain(?x?) = ?x? + * result_domain() = + * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) + * \endcode + */ +class DeviceDomain { + public: + /*! + * \brief Constructs a first-order domain of \p device_type, which may be + * \p kInvalidDeviceType to indicate the domain is free. + */ + explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {} + + /*! + * \brief Constructs a higher-order domain, where \p args_and_result contain the + * function argument and result domains in order. + */ + explicit DeviceDomain(std::vector args_and_result) + : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {} + + /*! \brief Returns true if domain is first-order and free. */ + bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); } + + /*! \brief Returns true if domain is higher-order. */ + bool is_higher_order() const { return !args_and_result_.empty(); } + + DLDeviceType first_order_device_type() const { + ICHECK(args_and_result_.empty()); + return device_type_; + } + + size_t function_arity() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.size() - 1UL; + } + + DeviceDomainPtr function_param(size_t i) const { + ICHECK(!args_and_result_.empty()); + ICHECK_LT(i + 1, args_and_result_.size()); + return args_and_result_[i]; + } + + DeviceDomainPtr function_result() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.back(); + } + + private: + /*! + * \brief If this is a function domain then always kInvalidDevice. Otherwise will be + * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is + * bound. + */ + const DLDeviceType device_type_; + + /*! + * \brief If this is a function domain then the sub-domains for each of the function's + * arguments, and the domain for its result. Otherwise empty. + */ + const std::vector args_and_result_; + + friend struct DeviceDomainHash; + friend struct DeviceDomainEqual; + friend class DeviceDomains; +}; + +// The following hash and equality helpers give each free first-order domain pointer its own +// distinct identity. +struct DeviceDomainHash { + size_t operator()(const DeviceDomainPtr& domain) const; +}; + +struct DeviceDomainEqual { + public: + bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const; +}; + +/*! + * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation + * built up by calls to \p Unify. + */ +class DeviceDomains { + public: + DeviceDomains() = default; + + /*! + * \brief Returns a domain appropriate for \p type who's result domain is bound + * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain + * will be free. + */ + static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type); + + /*! + * \brief Returns a higher-order domain with \p args_and_results. + */ + static DeviceDomainPtr MakeDomain(std::vector arg_and_results) { + return std::make_shared(std::move(arg_and_results)); + } + + /*! \brief Returns a domain with the given result device type appropriate \p device_type. */ + static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) { + ICHECK_NE(device_type, kInvalidDeviceType); + return MakeDomain(type, device_type); + } + + /*! \brief Returns a free domain appropriate for \p type. */ + static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); } + + /*! \brief Returns the domain representing the equivalence class containing \p domain. */ + DeviceDomainPtr Lookup(DeviceDomainPtr domain); + + /*! + * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs. + * + * Throws \p Error on failure. + */ + DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + + /*! + * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p + * rhs disagree on bound device type. + * + * Throws \p Error on failure. + */ + // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but + // given we have refs to functions I'm prepared to be surprised. + DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs); + + /*! + * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order, + * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as + * \p Unify. + * + * Throws \p Error on failure. + */ + void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + + /*! \brief Returns true if a domain is known for \p expr. */ + bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); } + + /*! \brief Returns the domain representing \p expr. */ + DeviceDomainPtr DomainFor(const Expr& expr); + + /*! + * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the + * callee is a primitive or special operation we handle it specially. Otherwise defers to \p + * DomainFor(call->op). + * + * This special handling is needed: + * - To handle the "on_device" and "device_copy" ops which constrain devices to the given devices. + * - To handle some special ops which constrain devices to the CPU. + * - To allow the same primitive to be called on different devices at different call sites. + * Since each call to the op can have a different domain we index the ops by the call expression + * rather than the op itself. + */ + DeviceDomainPtr DomainForCallee(const Call& call); + + /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */ + void UnifyExprExact(const Expr& lhs, const Expr& rhs); + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + */ + void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain); + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + * If \p expected_domain is higher-order but \p expr is first-order, require all arguments + * and the result of \p expected_domain to have the same domain as for \p expr. + */ + void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain); + + /*! \brief Returns true if \p domain contains any free sub-domains. */ + bool AnyFree(DeviceDomainPtr domain); + + /* + * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. + * This can be used to handle functions within tuples, references and ADTs since we don't + * attempt to track anything beyond 'the device' for expressions of those first-order types. + * + * Throws \p Error on failure. + */ + void Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain); + + /*! \brief Force all free domains in \p domain to default to \p default_device_type. */ + void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type); + + /*! + * \brief If \p domain is higher-order and its result domain is free, force it to + * \p default_device_type. Then force any remaining free domains to the result domain + * (freshly defaulted or original). If \p domain is first-order same as \p SetDefault. + */ + void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type); + + /*! \brief Returns one-line description of \p domain for debugging. */ + std::string ToString(DeviceDomainPtr domain); + + /*! \brief Returns description of entire system of constraints for debugging */ + std::string ToString(); + + /*! + * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). + */ + DeviceDomainPtr ResultDomain(DeviceDomainPtr domain); + + /*! + * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain + * comment). + */ + DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) { + return ResultDomain(domain)->first_order_device_type(); + } + + private: + /*! \brief Intrinsics we need to handle specially. */ + const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); + const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); + const Op& shape_of_op = Op::Get("vm.shape_of"); + const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); + const Op& shape_func_op = Op::Get("vm.shape_func"); + const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); + /*! \brief The CPU device type for special operators such as dynamic shape functions. */ + const DLDeviceType cpu_device_type_ = kDLCPU; + /*! \brief Placeholder for any first-order type. */ + Type arb_ = TupleType(); + /*! \brief The domain for first-order expressions on the CPU. */ + DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_); + + /*! \brief Maps expressions to their domains as determined during analysis. */ + std::unordered_map expr_to_domain_; + + /*! + * \brief Maps call expressions to the domains for their callee where the callee is a primitive. + */ + std::unordered_map call_to_callee_domain_; + + /*! \brief Maps device domains to their equivalent domains as determined during unification. */ + std::unordered_map + domain_to_equiv_; +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc new file mode 100644 index 000000000000..35bf406263e4 --- /dev/null +++ b/src/relay/transforms/device_planner.cc @@ -0,0 +1,1123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_planner.cc + * \brief Determines a unique device to hold the result of every Relay sub-expression. + * + * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. + * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the + * specific target associated with D (this is recovered independently via a TargetMap), and we + * do not track the storage scope within D (this is yet to be implemented). + * + * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', + * see below. + * + * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: + * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and + * 'dst_dev_type' device type, which constrain the argument and context of the call + * respectively. It is ok if source and destination devices are the same, such no-op copies + * will be removed after accounting for the device preference. + * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which + * constrains the argument of the call, but (usually, see below) leaves the context + * unconstrained. These are called 'annotations' in the rest of the code, have no operational + * significance by themselves, but may trigger the insertion of a new "device_copy". + * - In two situations the result of an "on_device" CallNode may also be constrained to the + * given device: + * - The "on_device" call occurs at the top-level of a function body, or occurs as an + * immediately let-bound expression. In this situation the extra degree of freedom in + * the function result and let-binding leads to surprising device copies, so we simply + * force the function result or let-bound variable to the given device. + * - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted + * it ourselves during an earlier invocation of this pass. This helps make this pass + * idempotent. + * + * We proceed in four phases: + * + * Phase 0 + * ------- + * We rewrite the programs to handle some special cases: + * - "on_device" calls at the top-level of function or immediately let-bound are rewritten + * to have \code is_fixed=true \endcode. + * - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written + * \code on_device(expr.0, device_type_d) \endcode. I.e. we prefer to copy the projection from + * the tuple rather than project from a copy of the tuple. We'll do this by rewriting. + * + * Phase 1 + * ------- + * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see + * below) to all other Relay sub-expressions. (For idempotence we also respect any existing + * "param_device_types" and "result_device_type" function attributes we introduce below.) + * + * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the + * same device. However each call site can use a different device. In other words primitives are + * 'device polymorphic' since we compile and execute them for each required device. + * + * For most Relay expressions the device for the overall expression is the same as the device + * for it's sub-expressions. E.g. each field of a tuple must be on the same device as the tuple + * itself, the condition and arms of an \p if must all be on the same device as the overall if, + * and so on. + * + * Some special ops (or 'dialects') are handled: + * - Relay supports computing the shape of tensors and operators at runtime using "shape_of", + * "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors + * they describe may reside on any device. + * - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor". Again + * shapes reside on the CPU, but the allocated tensors may reside on any device. + * + * Two Relay expression have special handling: + * - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the + * overall let. However the result of \p e1 may be on a different device. + * - For a function \code fn(x, y) { body } \endcode the result of the function must be on the + * same device as \p body. However parameters \p x and \p may be on different devices, even + * different from each other. Every call to the function must use the same choice of parameter + * and result devices -- there is no 'device polymorphism' for Relay functions. + * + * Phase 2 + * ------- + * After flowing constraints we apply some defaulting heuristics (using a global default device) + * to fix the device for any as-yet unconstrained sub-expressions. + * - Unconstrained function result devices default to the global default device. + * - Unconstrained function parameters devices default to the device for the function result. + * - Unconstrained let-bound expression devices default to the device for the overall let. + * TODO(mbs): I may have over-innovated here and we simply want to bind all free domaints to + * the global default device. Worth a design doc with motivating examples I think. + * + * Phase 3 + * ------- + * Finally, the result of this analysis is reified into the result as: + * - Additional "param_device_types" (an Array) and "result_device_type" (Integer) + * attributes for every function (both top-level and local). These describe the devices for + * the function's parameters and the result. + * - Additional "device_copy" CallNodes where a copy is required in order to respect the + * intent of the original "on_device" CallNodes. + * - Additional "on_device" CallNodes where the device type of an expression does not match + * that of the lexically enclosing "on_device" CallNode or function attribute. In practice + * this means "on_device" CallNodes may appear in two places: + * - On a let-bound expression if its device differs from the overall let expression. + * - On a call argument if its device differs from the call result. In particular, the + * argument to a "device_copy" call will always be wrapped in an "on_device". (That may + * seem pedantic but simplifies downstream handling.) + * However since we make it easy to track devices for variables we never wrap an "on_device" + * around a var or global var. These uses of "on_device" imply both the argument and result are + * on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true, + * which helps make this pass idempotent. + * + * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover + * the device for any expression for their own use, e.g. during memory planning. All downstream + * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion + * to ANF must respect the lexical scoping convention: + * \code + * f(on_device(g(h(a, b), c), device_type=CPU)) + * ==> + * let %x0 = on_device(h(a, b), device_type=CPU) + * let %x1 = on_device(g(%x0), device-type=CPU) + * f(on_device(%x1, device_type=CPU)) + * \endcode + * + * This pass can be run before FuseOps it can use device-specific fusion rules. + * + * 'Stored on' vs 'Executes on' + * ---------------------------- + * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the + * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for + * primitives. + * + * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are + * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific + * devices, thus the notion of 'executes on' is mute. AOT backends on the other hand need to + * know exactly which device (possibly one of a number of available 'CPU'-like devices) is + * responsible for execution. Currently that's handled independently by the \p AnnotateTargets + * pass, but we'd like to fold that into device planning here to ensure everything is consistent. + * + * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay + * expression (eg an if expression) on one device even though the tensor data resides on + * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on' + * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just + * compile the function body for the function's result device. + * + * This works after conversion to ANF provided the compilation for a let expression is prepared + * to make a cross-device call. However we leave it to a downstream transformation to heuristically + * minimize cross-device calls by moving device copies out of functions. E.g.: + * \code + * def @f() { // execute on CPU + * let x = on_device(...GPU computation..., device_type=GPU); + * device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU) + * } + * def @main() { + * ... call @f() on CPU ... + * } + * \endcode + * could be rewritten to: + * \code + * def @f() { // execute on GPU + * let x = ...GPU computation...; + * ...GPU computation... + * } + * def @main() { + * let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU) + * ... use x on CPU ... + * } + * \endcode + * + * Higher-order shenanigans + * ------------------------ + * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions + * as arguments (even anonymous functions), return functions, evaluate conditional expressions + * over functions, and so on. We handle this during constraint solving using the domain: + * \code + * D ::= -- first-order + * | fn(D,...,D):D -- higher-order + * \endcode + * In this way we can determine the device for all function parameters and results. E.g. for + * \code + * let f = fn(x, y) { ... } + * let g = fn(f, z) { f(z, z) } + * g(f, on_device(..., device_type=CPU)) + * \endcode + * the parameters \p x and \p y will be on the CPU. + * + * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a + * function. Our analysis must guarantee that the function's parameters and result devices are + * consistent for \p e2, \p e3, and the context of the call. But: + * - Which device holds the closure result of evaluating \p e1 ? + * - If \p e2 is of function type, what does that mean when we say every function parameter + * is on a device? + * - If \p e1 returns a function, what does that mean when we say every function result is + * on a device? + * + * Since higher-order aspects are later compiled away (by 'defunctionalization' + * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular, + * we really don't want our domain \p D to allow for yet another device for the function closure. + * So we'll just force the 'device for a function' to be the same as the device for the function's + * result using the notion of the 'result domain' for a domain: + * \code + * result_domain() = + * result_domain(fn(D1,...,Dn):Dr) = result_domain(Dr) + * \endcode + * + * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the + * analysis encounters a function inside one of those it simply forces all argument and result + * devices for the function to match the device for the first-order expression. For example, + * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function + * parameters and result must similarly be on the GPU. + * + * ------- + * | AOR | This pass supports all of Relay. + * ------- + * ^ + * | + * `-- Mark's stamp of completeness :-) + * + * TODO(mbs): + * * Though on_device is the identity for all types we can't wrap it around functions/constructors + * taking type args (or at least not without changing type_infer.cc to see through them). + * This is not currently handled generally. + * * Proper diagnostics for unification failure using spans. + * * Make sure the pass is idempotent even after FuseOps etc. + * * Support application of constructors properly. Are they device polymorphic? + * * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'. + * * Support running the pass post FuseOps (so need to understand primitive functions, both + * outlines and lined) and post the VM transforms (probably need to support more intrinsic + * forms?). + * * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default + * device for primitives vs the default device for the rest of Relay. + * * We'll probably need some support for partial 'device polymorphism' for functions once we + * incorporate targets and memory scopes into the domain. For example it's ok for the function + * body to be executed on different device ids provided they have the same target and memory + * scope. + * * Might be simpler to just let every type have a device annotation rather than work in + * a separate domain? + * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. + * * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls + * in tuples at the top level of function bodies or main expression, irrespective of the + * "on_device" body. What's up with that? + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" +#include "./device_domains.h" + +namespace tvm { +namespace relay { +namespace transform { + +namespace { + +/****** +******* Phase 0 +*******/ + +/*! + * \brief Rewrites "on_device" calls to handle some special cases. + * + * \code + * let %x = on_device(e, device_type=d) + * ==> let %x = on_device(e, device_type=d, is_fixed=True) + * + * fn(%x) { on_device(e, device_type=d) } + * ==> fn(%x) { on_device(e, device_type=d, is_fixed=True) + * + * on_device(e).0 + * ==> on_device(e.0) + * \endcode + */ +class RewriteOnDevices : public ExprMutator { + public: + RewriteOnDevices() = default; + + private: + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + Expr tuple = VisitExpr(tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy. + Expr tuple_get_item = + TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + auto props = GetOnDeviceProps(tuple); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "wrapping tuple get item:" << std::endl + << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl + << "with \"on_device\" for device " << props.device_type; + return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false); + } else { + return tuple_get_item; + } + } + + Expr VisitExpr_(const LetNode* let_node) final { + auto expr = GetRef(let_node); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + Expr value = VisitExpr(inner_let_node->value); + auto props = GetOnDeviceProps(value); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising let-bound expression of let:" << std::endl + << PrettyPrint(expr) << std::endl + << "to be fixed to device " << props.device_type; + value = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + expr = VisitExpr(expr); + // TODO(mbs): Avoid copy. + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr, + /*span=*/std::get<2>(*itr)); + } + return expr; + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + Expr body = VisitExpr(function_node->body); + auto props = GetOnDeviceProps(body); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising body of function:" << std::endl + << PrettyPrint(GetRef(function_node)) << std::endl + << "to be fixed to device " << props.device_type; + body = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + // TODO(mbs): Avoid copy + return Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + } +}; + +/****** +******* Phase 1 +*******/ + +/* + * \brief Collects the system of device constraints for all sub-expressions in a module. + * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. + * + * Eg from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device. Later, + * from \code on_device(%x, device_type=d) \endcode we know \p %x must be on device \p d, and thus + * so must \p %y. + * + * Constraints can flow in interesting ways. E.g. in: + * \code + * let %f = fn(%x, %y) { add(%x, on_device(%y, device_type=d)) } + * let %g = fn(%f, %x, %y) { %f(%x, %y) } + * %g(%f, %a, %b) + * \endcode + * we discover \p %b must be on device \p d. + */ +class DeviceAnalyzer : public ExprVisitor { + public: + explicit DeviceAnalyzer(IRModule mod) + : mod_(std::move(mod)), domains_(std::make_unique()) {} + + /*! + * \brief Returns the expression-to-device-domain map for all expressions in all the global + * function definitions in the module. Expressions may have free domains, these will be resolved + * by \p DeviceDefaulter below. + */ + std::unique_ptr Analyze() { + VLOG_CONTEXT << "DeviceAnalyzer"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; + domains_->UnifyExprExact(pair.first, pair.second); + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + + // Find the higher-order domain for the callee. See DomainForCallee for the special rules + // for primitives. + VisitExpr(call_node->op); + auto func_domain = domains_->DomainForCallee(call); // higher-order + + // Build the domain for the function implied by its arguments and call context. + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + std::vector args_and_result_domains; + args_and_result_domains.reserve(call_node->args.size() + 1); + for (const auto& arg : call_node->args) { + args_and_result_domains.emplace_back(domains_->DomainFor(arg)); + VisitExpr(arg); + } + args_and_result_domains.emplace_back(domains_->DomainFor(call)); + auto implied_domain = + DeviceDomains::MakeDomain(std::move(args_and_result_domains)); // higher-order + + VLOG(1) << "initial call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied domain:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + + // The above must match. + try { + domains_->Unify(func_domain, implied_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:" + << std::endl + << PrettyPrint(call) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied call devices:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << e.what(); + } + + VLOG(1) << "final call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // Let var must be same device as value it is bound to. + domains_->UnifyExprExact(let->var, let->value); // may be higher-order + // Let body must be same device as overall let. + domains_->UnifyExprExact(let, let->body); // may be higher-order + + VisitExpr(let->var); + VisitExpr(let->value); + + expr = let->body; + } + + // Visit the last body + VisitExpr(expr); + } + + void VisitExpr_(const FunctionNode* function_node) final { + // No need to step into fused primitive functions as they are lowered individually according + // to the devices of all their call sites. + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + + // The function body domain must match the function result domain. + domains_->UnifyExprExact(function_node->body, + func_domain->function_result()); // may be higher-order + + VLOG(1) << "initial function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + // The parameter domains must match the function argument domains. + domains_->UnifyExprExact(function_node->params[i], + func_domain->function_param(i)); // may be higher-order + VisitExpr(function_node->params[i]); + } + + // If the function already has device attributes then we can further constrain the + // function's domain to match them. + if (GetFunctionResultDeviceType(function_node) != kInvalidDeviceType) { + std::vector args_and_result; + for (size_t i = 0; i < function_node->params.size(); ++i) { + args_and_result.emplace_back( + domains_->ForDeviceType(function_node->params[i]->checked_type(), + GetFunctionParamDeviceType(function_node, i))); + } + args_and_result.emplace_back(domains_->ForDeviceType( + function_node->body->checked_type(), GetFunctionResultDeviceType(function_node))); + auto annotation_domain = domains_->MakeDomain(std::move(args_and_result)); + try { + domains_->Unify(func_domain, annotation_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) + << "Function devices are incompatible with its \"on_device\" annotation. Function:" + << std::endl + << PrettyPrint(function) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and annotation devices:" << std::endl + << domains_->ToString(annotation_domain) << std::endl + << e.what(); + } + } + + VisitExpr(function_node->body); + + VLOG(1) << "final function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + } + + void VisitExpr_(const TupleNode* tuple_node) final { + Tuple tuple = GetRef(tuple_node); + for (size_t i = 0; i < tuple->fields.size(); i++) { + auto domain = domains_->DomainFor(tuple->fields[i]); // may be higher-order + domains_->UnifyExprCollapsed(tuple, domain); // collapse to first-order if needed + VisitExpr(tuple->fields[i]); + } + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + TupleGetItem tuple_get_item = GetRef(tuple_get_item_node); + auto domain = domains_->DomainFor(tuple_get_item); // may be higher-order + domains_->UnifyExprCollapsed(tuple_get_item_node->tuple, + domain); // collapse to first-order if needed + VisitExpr(tuple_get_item_node->tuple); + } + + class DevicePatternAnalyzer : public PatternVisitor { + public: + DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node) + : domains_(domains), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto var_domain = domains_->DomainFor(pattern_var_node->var); // may be higher order + domains_->UnifyExprCollapsed(GetRef(adt_node_), + var_domain); // collapse to first-order if needed + } + + /*! \brief (Mutable borrow of) the domains for all expressions processed so far. */ + DeviceDomains* domains_; + /*! \brief The expression for the ADT we are matching over. */ + const ExprNode* adt_node_; + }; + + void VisitPattern(const Pattern& pattern) final {} + + void VisitExpr_(const MatchNode* match_node) final { + // For match node, we unify the value and the rhs of each clause + Match match = GetRef(match_node); + auto match_domain = domains_->DomainFor(match); // may be higher-order + DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get()); + domains_->UnifyExprCollapsed(match->data, match_domain); // collapse to first-order if needed + for (const auto& clause : match->clauses) { + pattern_analyzer.VisitPattern(clause->lhs); + domains_->UnifyExprExact(clause->rhs, match_domain); + VisitExpr(clause->rhs); + } + VisitExpr(match_node->data); + } + + void VisitExpr_(const GlobalVarNode* global_var_node) final { + domains_->DomainFor(GetRef(global_var_node)); + } + + void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef(var_node)); } + + void VisitExpr_(const ConstantNode* constant_node) final { + domains_->DomainFor(GetRef(constant_node)); + } + + void VisitExpr_(const ConstructorNode* constructor_node) final { + // no-op, constructors are handled at their call-sites. + // TODO(mbs): Assumes eta-expansion + } + + void VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + auto domain = domains_->DomainFor(ife); // may be higher-order + domains_->UnifyExprCollapsed(if_node->cond, domain); // collapse to first-order if needed + domains_->UnifyExprExact(if_node->true_branch, domain); + domains_->UnifyExprExact(if_node->false_branch, domain); + VisitExpr(if_node->cond); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + } + + void VisitExpr_(const OpNode* op) final { + // no-op, primitive operators are handled at their call-sites. + } + + void VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + auto domain = domains_->DomainFor(ref_create_node->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_create, domain); // collapse to first-order if needed + VisitExpr(ref_create_node->value); + } + + void VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + auto domain = domains_->DomainFor(ref_read); // may be higher-order + domains_->UnifyExprCollapsed(ref_read_node->ref, domain); // collapse to first-order if needed + VisitExpr(ref_read_node->ref); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + auto domain = domains_->DomainFor(ref_write->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_write->ref, domain); // collapse to first-order if needed + domains_->UnifyExprCollapsed(ref_write, domain); // collapse to first-order if needed + VisitExpr(ref_write_node->ref); + VisitExpr(ref_write_node->value); + } + + /*! \brief The module we are analyzing. */ + IRModule mod_; + /*! \brief The domains for all expressions processed so far. */ + std::unique_ptr domains_; +}; + +/****** +******* Phase 2 +*******/ + +/*! + * \brief Ensures every sub-expression in a module has a device type, using both the global + * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes. + * + * E.g. in: + * \code + * def @main(%x, %y, %z) { + * let %a = add(%x, %y); + * multiply(%a, on_device(%z, device_type=d)) + * \endcode + * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, + * and the device for the function result, are still 'free'. The global 'default' device type + * is first used to 'fix' \p @main's result type, which in turn 'fixes' \p %x and \p %y, which + * in turn 'fixes' the device on which the \p add and \p multiply are executed. + * + * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap + * order. + */ +class DeviceDefaulter : public ExprVisitor { + public: + DeviceDefaulter(IRModule mod, std::unique_ptr domains, + DLDeviceType default_device_type) + : mod_(std::move(mod)), + domains_(std::move(domains)), + default_device_type_(default_device_type) {} + + std::unique_ptr Default() { + VLOG_CONTEXT << "DeviceDefaulter"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + if (domains_->AnyFree(func_domain)) { + VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); + } + VisitExpr(function_node->body); + } + + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + auto func_domain = domains_->DomainForCallee(call); // higher-order + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + if (domains_->AnyFree(func_domain)) { + // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) + // above. But for calls to primitives we may still need to force free domains to be + // defaulted. + VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + } + return ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // If the let-var device is still free force it to match the overall let. + auto let_domain = domains_->DomainFor(let); // may be higher-order + DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain); + ICHECK_NE(let_device_type, kInvalidDeviceType); + auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order + if (domains_->AnyFree(let_var_domain)) { + VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + domains_->SetDefault(let_var_domain, let_device_type); + VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + } + VisitExpr(let->var); + VisitExpr(let->value); + expr = let->body; + } + VisitExpr(expr); + } + + /*! \brief The module we are processing. */ + IRModule mod_; + /*! \brief The domains for all expressions. */ + std::unique_ptr domains_; + /*! \brief The default device type. */ + DLDeviceType default_device_type_; +}; + +/****** +******* Phase 3 +*******/ + +/*! + * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every + * sub-expression in a module can be easily recovered by a later transformation using simple + * lexical scoping rules (e.g. for memory planning). + * + * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard + * any existing "device_copy" CallNodes which are no-ops. + * + * - Functions are given "param_device_types" and "result_device_type" attributes to capture + * the device type for its parameters and result. + * + * - Additional "device_copy" CallNodes are inserted wherever there's a transition between + * storage device types. Since the DeviceAnalyzer phase succeeded this can only happen + * where the original program explicitly allowed a transition using an "on_device" CallNode. + * That is, we do not not try to 'fix' a program with inconsistent devices. + * + * - Additional "on_device" CallNodes are inserted so that a later transform can discover + * the device for an arbitrary sub-expression by looking only for the lexically enclosing + * "on_device" CallNode or "on_device" function attribute. In particular, since function + * arguments and let-bound expressions can be on a device different from the function + * or let body itself we will insert "on_device" CallNodes to spell out any differences. This + * applies even to the argument to a "device_copy" CallNode, which may look pedantic but + * keeps downstream processing simple. The "on_device" calls should be removed before code gen, + * which is easily done on-the-fly. + * + * For example, we'll end up with programs that look like: + * \code + * def @main(%x, %y, param_device_types=[...], result_device_type=...) { + * let %a = on_device(..., device_type=..., is_fixed=True) + * @f(%a, device_copy(on_device(..., device_type=..., is_fixed=True), + * src_device_type=..., dst_device_type=...)) + * } + * \endcode + */ +class DeviceCapturer : public ExprMutator { + public: + DeviceCapturer(IRModule mod, std::unique_ptr domains) + : mod_(std::move(mod)), domains_(std::move(domains)) {} + + IRModule Capture() { + VLOG_CONTEXT << "CaptureDevices"; + IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map); + for (const auto& pair : mod_->functions) { + VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; + result->Add(pair.first, Downcast(Mutate(pair.second))); + } + return result; + } + + private: + // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode + + Expr VisitExpr_(const TupleNode* tuple_node) final { + auto tuple = GetRef(tuple_node); + Array fields; + fields.reserve(tuple_node->fields.size()); + for (const auto& field : tuple_node->fields) { + fields.push_back(VisitChild(tuple, field)); + } + // TODO(mbs): Avoid copy + return Tuple(std::move(fields), tuple_node->span); + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return GetRef(function_node); + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + VLOG(1) << "capturing function:" << std::endl + << PrettyPrint(function) << std::endl + << "with domain:" << std::endl + << domains_->ToString(func_domain); + + // Gather the parameter and result device types for the function attributes. + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + Array param_device_types; + param_device_types.reserve(function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType); + param_device_types.push_back(param_device_type); + } + + // Rewrite the body. Note that the body may have begun with an "on_device" so + // be prepared to insert a "device_copy". + Expr body = VisitChild( + /*lexical_device_type=*/result_device_type, + /*expected_device_type=*/result_device_type, + /*child_device_type=*/GetDeviceType(function_node->body), function_node->body); + + // TODO(mbs): Avoid copy + Function func = Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + return FunctionOnDevice(func, param_device_types, result_device_type); + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + DLDeviceType call_device_type = GetDeviceType(call); + + auto on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + // We're done with the original "on_device" calls and can pinch them out. + // Note that this step has already been simulated by GetDeviceType. + return VisitExpr(on_device_props.body); + } + + auto device_copy_props = GetDeviceCopyProps(call_node); + if (device_copy_props.body.defined()) { + DLDeviceType src_device_type = device_copy_props.src_dev_type; + ICHECK_EQ(call_device_type, device_copy_props.dst_dev_type); + if (call_device_type == src_device_type) { + // We can pinch out existing "device_copy" CallNodes if their source and destinations + // match. + return VisitExpr(device_copy_props.body); + } + // else: handle as for any other call. + } + + auto func_domain = domains_->DomainForCallee(call); // higher-order + VLOG(1) << "considering call:" << std::endl + << PrettyPrint(call) << std::endl + << "on device " << call_device_type << " with function domain:" << std::endl + << domains_->ToString(func_domain); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + + // The callee is on the current device. + Expr op = VisitChild( + /*lexical_device_type=*/call_device_type, + /*expected_device_type=*/call_device_type, + /*child_device_type=*/result_device_type, call_node->op); + + // Each argument can be on the device for the corresponding function parameter. However if + // any of those differ from the overall call device then wrap them in an "on_device" to + // help downstream transforms track devices lexically. + Array args; + args.reserve(call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), call->args.size()); + for (size_t i = 0; i < call_node->args.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType) + << "for parameter " << i << " for call:" << std::endl + << PrettyPrint(call); + args.push_back(VisitChild(/*lexical_device_type=*/call_device_type, + /*expected_device_type=*/param_device_type, + /*child_device_type=*/GetDeviceType(call_node->args[i]), + call_node->args[i])); + } + // TODO(mbs): Avoid copy + return Call(std::move(op), std::move(args), call_node->attrs, call_node->type_args, + call_node->span); + } + + Expr VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iterate through chained lets, provided they all agree on their device type. + DLDeviceType let_device_type = GetDeviceType(expr); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + if (GetDeviceType(inner_let) != let_device_type) { + // We have a device transition which needs to be handled. + break; + } + // The let-bound value can be on a different device than the overall let. However if those + // devices don't agree wrap the let-bound value in an "on_device" to help downstream + // transforms track devices lexically. + Expr value = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/GetDeviceType(inner_let_node->var), + /*child_device_type=*/GetDeviceType(inner_let_node->value), + inner_let_node->value); + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + Expr body = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/let_device_type, + /*child_device_type=*/GetDeviceType(expr), expr); + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, + /*span=*/std::get<2>(*itr)); + } + return body; + } + + Expr VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + Expr cond = VisitChild(ife, if_node->cond); + Expr true_branch = VisitChild(ife, if_node->true_branch); + Expr false_branch = VisitChild(ife, if_node->false_branch); + // TODO(mbs): Avoid copy + return If(cond, true_branch, false_branch, if_node->span); + } + + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + auto tuple_get_item = GetRef(tuple_get_item_node); + Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy + return TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + } + + Expr VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + Expr value = VisitChild(ref_create, ref_create_node->value); + // TODO(mbs): Avoid copy + return RefCreate(value, ref_create_node->span); + } + + Expr VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + Expr ref = VisitChild(ref_read, ref_read_node->ref); + // TODO(mbs): Avoid copy + return RefRead(ref, ref_read_node->span); + } + + Expr VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + Expr ref = VisitChild(ref_write, ref_write_node->ref); + Expr value = VisitChild(ref_write, ref_write_node->value); + // TODO(mbs): Avoid copy + return RefWrite(ref, value, ref_write_node->span); + } + + Expr VisitExpr_(const MatchNode* match_node) final { + auto match = GetRef(match_node); + Expr data = VisitChild(match, match_node->data); + Array clauses; + clauses.reserve(match_node->clauses.size()); + for (const auto& clause : match_node->clauses) { + Pattern lhs = VisitPattern(clause->lhs); // actually a no-op, so we're not checking vars + Expr rhs = VisitChild(match, clause->rhs); + clauses.push_back(Clause(lhs, rhs)); + } + // TODO(mbs): Avoid copy + return Match(data, std::move(clauses), match_node->complete, match_node->span); + } + + DLDeviceType GetDeviceType(const Expr& expr) { + // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. + auto props = GetOnDeviceProps(expr); + Expr true_expr = props.body.defined() ? props.body : expr; + ICHECK(domains_->contains(true_expr)); + // If expr is higher order we'll return only the result domain's device type. + DLDeviceType device_type = domains_->ResultDeviceType(domains_->DomainFor(true_expr)); + ICHECK_NE(device_type, kInvalidDeviceType) + << "no device type was determined for expression:" << std::endl + << PrettyPrint(true_expr); + return device_type; + } + + /*! + * \brief Reconcile the \p child_device_type for \p child with both the \p expected_device_type + * (as required by the expression context the \p child is in) and the \p lexical_device_type + * (as a downstream transform would infer based only on lexically enclosing "on_device" + * CallNodes and function attributes.) Generally \p lexical_device_type and \p + * expected_device_type are the same by definition, but may differ in arguments to functions + * and let-bound expressions. + * + * If \p child_device_type differs from \p expected_device_type, wrap it as: + * \code + * device_copy(on_device(child', device_type=child_device_type), + * src_dev_type=child_device_type, dst_dev_type=expected_device_type) + * \endcode + * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the + * child. + * + * If \p expected_device_type differs from \p lexical_device_type, then (also) wrap + * the expression as: + * \code + * on_device(..., device_type=expected_device_type) + * \endcode + * + * TODO(mbs): There's no attempt at sharing here. If usage of child's node could be wrapped + * by a "device_copy", even though those copies will generally all be to the same destination + * device. + */ + Expr VisitChild(DLDeviceType lexical_device_type, DLDeviceType expected_device_type, + DLDeviceType child_device_type, const Expr& child) { + ICHECK_NE(lexical_device_type, kInvalidDeviceType); + ICHECK_NE(expected_device_type, kInvalidDeviceType); + if (child->IsInstance()) { + // Primitive operators don't need to be rewritten and can have a different domain for + // each call site. + return child; + } + Expr result = VisitExpr(child); + if (child_device_type != expected_device_type) { + VLOG(1) << "creating " << DeviceCopyOp()->name << " from device type " << child_device_type + << " to device type " << expected_device_type << " for:" << std::endl + << PrettyPrint(result); + // Also wrap the child in an "on_device" so downstream transforms can track devices + // lexically. + result = MaybeOnDevice(result, child_device_type, /*is_fixed=*/true); + result = DeviceCopy(result, child_device_type, expected_device_type); + } + if (expected_device_type != lexical_device_type) { + VLOG(1) << "creating " << OnDeviceOp()->name << " for device type " << expected_device_type + << " for:" << std::endl + << PrettyPrint(result); + result = MaybeOnDevice(result, expected_device_type, /*is_fixed=*/true); + } + return result; + } + + /*! + * Common case of visiting a direct \p child of \p parent where by default the \p child + * is expected to be on the same device as the \p parent. + */ + Expr VisitChild(const Expr& parent, const Expr& child) { + DLDeviceType expected_device_type = GetDeviceType(parent); + DLDeviceType child_device_type = GetDeviceType(child); + return VisitChild(expected_device_type, expected_device_type, child_device_type, child); + } + + /*! \brief Module we are rewriting, so we can lookup global variables. */ + IRModule mod_; + /*! \brief Device domain for every expression from DeviceAnalyzer. */ + std::unique_ptr domains_; +}; + +/*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */ +tvm::transform::Pass Rewrite() { + auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { + return Downcast(RewriteOnDevices().Mutate(f)); + }; + return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {}); +} + +/*! \brief Run the remaining phases. */ +tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { + return tvm::transform::CreateModulePass( + [=](IRModule mod, tvm::transform::PassContext pass_cnxt) -> IRModule { + // Collect the system of constraints for every sub-expression using existing "on_device" + // and "device_copy" calls. + std::unique_ptr domains = DeviceAnalyzer(mod).Analyze(); + VLOG(1) << "Domains after analysis:" << std::endl << domains->ToString(); + + // Choose sensible default devices for every sub-expression if otherwise unconstrained + // by existing "on_device" or "device_copy" calls. + domains = DeviceDefaulter(mod, std::move(domains), default_device_type).Default(); + VLOG(1) << "Domains after defaulting: " << std::endl << domains->ToString(); + + // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture + // the above map, and attach additional "param_device_types" and "result_device_type" + // attributes to all function definitions. + return DeviceCapturer(mod, std::move(domains)).Capture(); + }, + /*opt_level=*/0, "PlanDevicesCore", {}); +} + +} // namespace + +/****** +******* Overall composite Pass +*******/ + +// This function is declared in the public . +TVM_DLL tvm::transform::Pass PlanDevices(DLDeviceType default_device_type) { + std::vector passes; + passes.emplace_back(Rewrite()); + passes.emplace_back(PlanDevicesCore(default_device_type)); + return tvm::transform::Sequential(std::move(passes), "PlanDevices"); +} + +TVM_REGISTER_GLOBAL("relay._transform.PlanDevices") + .set_body_typed([](const Device& default_device) { + return PlanDevices(default_device.device_type); + }); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/relay/relay/transforms/device_domains_test.cc b/tests/cpp/relay/relay/transforms/device_domains_test.cc new file mode 100644 index 000000000000..8f263c3b3273 --- /dev/null +++ b/tests/cpp/relay/relay/transforms/device_domains_test.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Just a smoke test for the device planner's unification domain, mostly to tease out how we'd + * like to organize our cpp unit tests for functionality that's not obviously a Pass or should + * be exposed via FFI. + */ + +// TODO(mbs): Revisit cpp unit test layout or setup include dir at root of src/ +#include "../../../src/relay/transforms/device_domains.h" + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { +namespace { + +IRModule TestModule() { + return InferType()(tvm::parser::ParseModule("test", R"( + #[version = "0.0.5"] + def @f(%x : Tensor[(3, 7), float32], %y : Tensor[(3, 7), float32]) { + add(%x, %y) + } + )")); +} + +TEST(DeviceDomains, SmokeTest) { + DeviceDomains domains; + IRModule mod = TestModule(); + Function f = Downcast(mod->Lookup("f")); + + DeviceDomainPtr actual_add_domain = domains.DomainForCallee(Downcast(f->body)); + DeviceDomainPtr x_domain = domains.DomainFor(f->params[0]); + DeviceDomainPtr y_domain = domains.DomainFor(f->params[1]); + DeviceDomainPtr result_domain = DeviceDomains::Free(f->ret_type); + std::vector arg_and_results; + arg_and_results.push_back(x_domain); + arg_and_results.push_back(y_domain); + arg_and_results.push_back(result_domain); + DeviceDomainPtr implied_add_domain = DeviceDomains::MakeDomain(std::move(arg_and_results)); + domains.Unify(actual_add_domain, implied_add_domain); + domains.Unify(x_domain, DeviceDomains::ForDeviceType(f->params[0]->checked_type(), kDLCUDA)); + + EXPECT_EQ(domains.ResultDeviceType(y_domain), kDLCUDA); + EXPECT_EQ(domains.ResultDeviceType(result_domain), kDLCUDA); +} + +} // namespace +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py new file mode 100644 index 000000000000..2252d8a235c9 --- /dev/null +++ b/tests/python/relay/test_pass_plan_devices.py @@ -0,0 +1,1320 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License + + +"""Unit tests for the PlanDevices pass. We check: + - The pass alone given the expected AST, though we need to manually run InferTypes. + - The pass is idempotent. + - Execution on the VM backend yields the correct result.""" + +import tvm +from tvm import relay +import tvm.testing +import numpy as np + +CPU = tvm.device("cpu") # device_type=1 +GPU = tvm.device("cuda") # device_type=2 +DEFAULT = GPU + +core = tvm.IRModule() +core.import_from_std("core.rly") + + +def rewrite_and_assert(in_mod, expected_mod): + """Manually run the pass and assert it's structurally equals to the expected.""" + actual_mod = relay.transform.InferType()(in_mod) + actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod) + actual_mod = relay.transform.InferType()(actual_mod) + expected_mod = relay.transform.InferType()(expected_mod) + if not tvm.ir.structural_equal(actual_mod, expected_mod, True): + # Print everything in full so we can see what's going on when things fail. + print("Input module:") + print(in_mod) + print("Expected module:") + print(expected_mod) + print("Actual module:") + print(actual_mod) + # Assert again so as to see the actual disagreeing sub-expressions. + tvm.ir.assert_structural_equal(actual_mod, expected_mod, True) + + +def eval_and_assert(in_mod: tvm.IRModule, reference_func, args): + """Test the standard compilation flow gives us a function which agrees with the Numpy + reference implementation.""" + if not tvm.runtime.enabled("cuda"): + print("Not evaluating since GPU is not available") + return + with tvm.transform.PassContext(opt_level=3): + compiled = relay.create_executor("vm", mod=in_mod, device=GPU, target="cuda").evaluate() + actual = compiled(*args).numpy() + expected = reference_func(*args) + tvm.testing.assert_allclose(actual, expected) + + +def rand(shape): + return np.random.rand(*shape).astype("float32") + + +def rands(shape, n): + return [rand(shape) for i in range(n)] + + +def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, args): + """Test in_mod against expected_mod and reference_func using args.""" + # Correctness + rewrite_and_assert(in_mod, expected_mod) + # Idempotence + rewrite_and_assert(expected_mod, expected_mod) + # The VM can compile and possibly even run the module + # TODO(mbs): Disabled until VM supports new device planning. + # if not (reference_func is None) and not (args is None): + # eval_and_assert(in_mod, reference_func, args) + + +def test_plain(): + # Everything defaults to GPU + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[2, 2, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_left_add_on_cpu(): + # Force some args to be on CPU, rest default to GPU. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = add(%c, %d); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_left_add_on_cpu_via_copy(): + # As for test_left_add_on_cpu, but with an explicit device_copy. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = device_copy(%0, src_dev_type=1, dst_dev_type=2); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = add(%c, %d); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_both_adds_on_cpu(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=1); + subtract(%2, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1, 1], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = add(%c, %d); + %3 = on_device(%2, device_type=1, is_fixed=True); + %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + subtract(%4, %5) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_sharing(): + # The same add sub-expression is annotated twice. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = on_device(%0, device_type=1); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = on_device(%0, device_type=1, is_fixed=True); + %3 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %4 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + subtract(%3, %4) + } + """ + ) + + def ref(a, b): + x = np.add(a, b) + return np.subtract(x, x) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_let_on_cpu(): + # The device for a let-bound expression can flow from uses of the let-bound var. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %l = add(%a, %b); + let %r = add(%c, %d); + %0 = on_device(%l, device_type=1); + subtract(%0, %r) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + let %l = on_device(%0, device_type=1, is_fixed=True); + let %r = add(%c, %d); + %1 = device_copy(%l, src_dev_type=1, dst_dev_type=2); + subtract(%1, %r) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_func_param_on_cpu(): + # Devices for function parameters flow to call sites. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %f = fn (%x, %y) { + %0 = add(%x, %y); + on_device(%0, device_type=1) + }; + %1 = %f(%a, %b); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1, 1], result_device_type=1) { + let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + add(%x, %y) + }; + %0 = %f(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_func_result_on_cpu(): + # Devices for call sites flow to function results. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %f = fn (%x, %y) { + add(%x, %y) + }; + %0 = %f(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + add(%x, %y) + }; + let %f = on_device(%0, device_type=1, is_fixed=True); + %1 = %f(%a, %b); + %2 = on_device(%1, device_type=1, is_fixed=True); + %3 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %4 = add(%c, %d); + subtract(%3, %4) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_higher_order(): + # The constraint on %a flows back to %y via %f and %h + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %f = fn (%g) { + fn (%a) { + %0 = on_device(%a, device_type=1); + %1 = %g(%0); + add(%1, %x) + } + }; + let %h = fn (%b) { + negative(%b) + }; + %2 = %f(%h); + %3 = %2(%y); + subtract(%x, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) { + let %f = fn (%g, param_device_types=[2], result_device_type=2) { + fn (%a, param_device_types=[1], result_device_type=2) { + %0 = device_copy(%a, src_dev_type=1, dst_dev_type=2); + %1 = %g(%0); + add(%1, %x) + } + }; + let %h = fn (%b, param_device_types=[2], result_device_type=2) { + negative(%b) + }; + %2 = %f(%h); + %3 = %2(%y); + subtract(%x, %3) + } + """ + ) + + def ref(x, y): + def f(g): + return lambda a: np.add(g(a), x) + + def h(b): + return np.negative(b) + + return np.subtract(x, f(h)(y)) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_function_in_tuple(): + # Since %f ends up in a tuple its argument and result is forced to be on the CPU + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = on_device(%b, device_type=1); + add(%a, %0) + }; + let %t = (%f, %x); + %1 = %t.1; + %2 = %t.0; + %2(%1, %y) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + add(%a, %b) + }; + let %t = (%f, %x); + %0 = %t.1; + %1 = %t.0; + %1(%0, %y) + } + """ + ) + + def ref(x, y): + return np.add(x, y) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_device_copy(): + const = rand((5, 7)) + metatable = {"relay.Constant": [relay.const(const)]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32]) { + %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + add(%0, meta[relay.Constant][0]) + } + """, + "from_string", + None, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=2) { + %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + add(%0, meta[relay.Constant][0]) + } + """, + "from_string", + None, + metatable, + ) + + def ref(x): + return np.add(x, const) + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_shape_func(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64]) { + %0 = fn (%y: Tensor[(?), float32]) { + nn.relu(%y) + }; + let %p = on_device(%0, device_type=2, is_fixed=True); + %1 = on_device(%x, device_type=2, is_fixed=True); + %2 = vm.shape_of(%1, dtype="int64"); + %3 = (%2,); + %4 = (%s,); + vm.shape_func(%p, %3, %4, is_input=[False]) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], + param_device_types=[2, 1], result_device_type=1) { + %0 = fn (%y: Tensor[(?), float32], param_device_types=[2], result_device_type=2) { + nn.relu(%y) + }; + let %p = on_device(%0, device_type=2, is_fixed=True); + %1 = vm.shape_of(%x, dtype="int64"); + %2 = (%1,); + %3 = (%s,); + vm.shape_func(%p, %2, %3, is_input=[False]) + } + """ + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_shape_of(): + # We need to use is_fixed=True in the on_device call so that the tensor will be on the GPU. Otherwise the + # result defaults to the result device for @main which is the CPU, thus forcing a copy. + # TODO(mbs): Perhaps the defaulting heuristics are being too clever? + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?, ?), float32]) { + %0 = on_device(%x, device_type=2, is_fixed=True); + vm.shape_of(%0, dtype="int64") + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?, ?), float32], param_device_types=[2], result_device_type=1) { + vm.shape_of(%x, dtype="int64") + } + """ + ) + + def ref(x): + return x.shape + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_alloc_storage(): + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%size: int64, %alignment: int64) { + memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + } + """, + "from_string", + core, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%size: int64, %alignment: int64, param_device_types=[1, 1], result_device_type=2) { + memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + } + """, + "from_string", + core, + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_alloc_tensor(): + shape = np.array([3, 2]) + metatable = {"relay.Constant": [relay.const(shape, dtype="int64")]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%sto: Storage[]) { + memory.alloc_tensor(%sto, 0, meta[relay.Constant][0], + const_shape=meta[relay.Constant][0], assert_shape=[]) + } + """, + "from_string", + core, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%sto: Storage[], param_device_types=[2], result_device_type=2) { + %0 = on_device(0, device_type=1, is_fixed=True); + %1 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) + } + """, + "from_string", + core, + metatable, + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_reshape_tensor(): + newshape = [2, 4, 2] + metatable = {"relay.Constant": [relay.const(newshape, dtype="int64")]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(2, 8), float32]) { + vm.reshape_tensor(%x, meta[relay.Constant][0], newshape=[2, 4, 2]) + } + """, + "from_string", + None, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(2, 8), float32], param_device_types=[2], result_device_type=2) { + %0 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) + } + """, + "from_string", + None, + metatable, + ) + + def ref(x): + return np.reshape(x, newshape) + + exercise(input(), expected(), ref, rands((2, 8), 1)) + + +def test_dynamic_input(): + # There's nothing special about inferring devices for partially unknown types. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32]) { + add(%x0, %x1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32], + param_device_types=[2, 2], result_device_type=2) { + add(%x0, %x1) + } + """ + ) + + def ref(x0, x1): + return np.add(x0, x1) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_redundant_annotation(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1); + %2 = subtract(%1, %z); + %3 = on_device(%0, device_type=1); + add(%2, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2], result_device_type=2) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%0, device_type=1, is_fixed=True); + %4 = subtract(%2, %z); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + add(%4, %5) + } + """ + ) + + def ref(x, y, z): + a = np.add(x, y) + return np.add(np.subtract(a, z), a) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_annotate_expr(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2); + %2 = subtract(%1, %z); + on_device(%2, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[2, 2, 1], result_device_type=1) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + subtract(%2, %z) + } + """ + ) + + def ref(x, y, z): + return np.subtract(np.add(x, y), z) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_annotate_all(): + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1); + %2 = subtract(%1, %z); + on_device(%2, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1], result_device_type=1) { + %0 = add(%x, %y); + subtract(%0, %z) + } + """ + ) + + def ref(x, y, z): + return np.subtract(np.add(x, y), z) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_conv_network(): + r"""The network and devices are as follows: + data1 data2 <--- CPU + | | + conv2d conv2d <--- CPU + \ / + \ / + add <--- GPU + | + conv2d <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], + %weight: Tensor[(64, 64, 3, 3), float32]) { + %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %1 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=1); + %4 = add(%2, %3); + %5 = on_device(%4, device_type=2); + %6 = nn.conv2d(%5, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + on_device(%6, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], + %weight: Tensor[(64, 64, 3, 3), float32], param_device_types=[1, 1, 1], result_device_type=1) { + %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %3 = on_device(%2, device_type=1, is_fixed=True); + %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %6 = add(%4, %5); + %7 = on_device(%6, device_type=2, is_fixed=True); + %8 = device_copy(%7, src_dev_type=2, dst_dev_type=1); + nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) + } + """ + ) + + # Don't try to execute, we don't have a reference conv2d + exercise(input(), expected(), None, None) + + +def test_tuple_get_item(): + # Note that the device copy should be placed after projection rather than before. This is handled by + # a heuristic in the pass. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(3, 3, 4), float32]) { + let %t = split(%x, indices_or_sections=3); + %0 = on_device(%t, device_type=1); + %1 = on_device(%t, device_type=1); + %2 = %0.0; + %3 = %1.1; + %4 = subtract(%2, %3); + on_device(%4, device_type=2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(3, 3, 4), float32], param_device_types=[1], result_device_type=2) { + %0 = split(%x, indices_or_sections=3); + let %t = on_device(%0, device_type=1, is_fixed=True); + %1 = %t.0; + %2 = on_device(%1, device_type=1, is_fixed=True); + %3 = %t.1; + %4 = on_device(%3, device_type=1, is_fixed=True); + %5 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %6 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + subtract(%5, %6) + } + """ + ) + + def ref(x): + t = np.split(x, 3) + return np.subtract(t[0], t[1]) + + exercise(input(), expected(), ref, rands((3, 3, 4), 1)) + + +def test_propogation(): + r""" The network and devices are as follows: + x <--- CPU + | + log <--- CPU + / \ + log2 log10 <--- GPU + \ / + add <--- GPU + | + tan <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32]) { + %0 = log(%x); + %1 = on_device(%0, device_type=1); + %2 = log2(%1); + %3 = on_device(%0, device_type=1); + %4 = log10(%3); + %5 = on_device(%2, device_type=2); + %6 = on_device(%4, device_type=2); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2); + %9 = tan(%8); + on_device(%9, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=1) { + %0 = log(%x); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%0, device_type=1, is_fixed=True); + %4 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %5 = log2(%2); + %6 = log10(%4); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2, is_fixed=True); + %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + tan(%9) + } + """ + ) + + def ref(x): + y = np.log(x) + return np.tan(np.add(np.log2(y), np.log10(y))) + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_fusible_network(): + r""" The network is as follows: + x y <--- GPU + \ / + add <--- GPU + / \ + negative \ <--- CPU + \ \ + \ negative <--- GPU + \ / + add <--- GPU + | + negative <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2); + %2 = negative(%1); + %3 = on_device(%2, device_type=1); + %4 = negative(%0); + %5 = add(%3, %4); + %6 = on_device(%5, device_type=2); + %7 = negative(%6); + on_device(%7, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_device_types=[2, 2], result_device_type=1) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %3 = negative(%2); + %4 = on_device(%3, device_type=1, is_fixed=True); + %5 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + %6 = negative(%0); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2, is_fixed=True); + %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + negative(%9) + } + """ + ) + + def ref(x, y): + z = np.add(x, y) + return np.negative(np.add(np.negative(z), np.negative(z))) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_unpropagatable_graph(): + r"""The network is as follows: + a b <--- CPU + \ / + \ / c d <--- GPU + \ / \ / + add \ / <--- CPU + \ \ / + \ multiply <--- GPU + \ / + subtract <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = multiply(%c, %d); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=2); + %4 = subtract(%2, %3); + on_device(%4, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=1) { + %0 = multiply(%c, %d); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = add(%a, %b); + %3 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.multiply(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_conditional(): + # The conditional is over a function type, thus exercising the first-order/higher-order domain handling. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + let %f = fn (%a) { + %0 = on_device(%y, device_type=1, is_fixed=True); + add(%a, %0) + }; + let %g = fn (%a1) { + subtract(%a1, %y) + }; + let %h = if (%x) { + %f + } else { + %g + }; + %h(%z) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1], result_device_type=1) { + let %f = fn (%a, param_device_types=[1], result_device_type=1) { + add(%a, %y) + }; + let %g = fn (%a1, param_device_types=[1], result_device_type=1) { + subtract(%a1, %y) + }; + let %h = if (%x) { + %f + } else { + %g + }; + %h(%z) + } + """ + ) + + def ref(x, y, z): + def f(a): + return np.add(a, y) + + def g(a): + return np.subtract(a, y) + + h = f if x else g + return h(z) + + exercise(input(), expected(), ref, [True, rand((5, 7)), rand((5, 7))]) + + +def test_global(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = on_device(%b, device_type=1); + add(%a, %0) + } + + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + @f(%y, %x) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) -> Tensor[(5, 7), float32] { + %0 = device_copy(%b, src_dev_type=1, dst_dev_type=2); + add(%a, %0) + } + + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[1, 2], result_device_type=2) -> Tensor[(5, 7), float32] { + @f(%y, %x) + } + """ + ) + + def ref(x, y): + def f(a, b): + return np.add(a, b) + + return f(x, y) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_ref(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %r = ref(%x); + %0 = on_device(%y, device_type=1); + ref_write(%r, %0); + %1 = ref_read(%r); + add(%x, %1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) { + let %r = ref(%x); + %0 = device_copy(%y, src_dev_type=1, dst_dev_type=2); + ref_write(%r, %0); + %1 = ref_read(%r); + add(%x, %1) + } + """ + ) + + def ref(x, y): + r = {"value": x} + r["value"] = y + return np.add(x, r["value"]) + + # Don't try to execute, no backend currently supports both hetrogeneous devices and references. + exercise(input(), expected(), None, None) + + +def test_adt(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + type List[A] { + Cons(A, List[A]), + Nil, + } + def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { + %0 = on_device(%y, device_type=1, is_fixed=True); + %1 = Nil; + %2 = Cons(%0, %1); + let %l = Cons(%x, %2); + match? (%l) { + Cons(%z, _) => %z + } + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + type List[A] { + Cons(A, List[A]), + Nil, + } + def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + %0 = Nil; + %1 = Cons(%y, %0); + let %l = Cons(%x, %1); + match? (%l) { + Cons(%z, _) => %z + } + } + """ + ) + + def ref(x, y): + l = [x, y] + return l[0] + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 725ae75af4997ff3a5107cc82d64609773de23a0 Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 30 Sep 2021 02:30:13 +0900 Subject: [PATCH 20/20] Fix flaky NMS test by making sure scores are unique (#9140) --- tests/python/frontend/pytorch/test_forward.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index c27469edf1d7..9238acd5f049 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1963,8 +1963,9 @@ def _gen_rand_inputs(num_boxes): boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5 boxes[:, 2] += boxes[:, 0] boxes[:, 3] += boxes[:, 1] - scores = torch.from_numpy(np.random.uniform(-1, 1, size=(num_boxes,)).astype(np.float32)) - return boxes, scores + scores = np.linspace(0, 1, num=num_boxes).astype("float32") + np.random.shuffle(scores) + return boxes, torch.from_numpy(scores) targets = ["llvm", "cuda"]