From 2d37f2cbd9a802daa2e8695d272551d86b22bd7a Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Tue, 28 Sep 2021 10:27:23 -0700 Subject: [PATCH] [Relay] Prepare for new plan_devices.cc (part II) (#9130) * Prepare for new plan_devices.cc (part II) These changes came from changing https://github.com/apache/tvm/pull/9038 to use tvm.parser.fromtext instead of manual AST construction. - Demote FunctionOnDeviceAttrs to just a pair of DictAttrs entries so that the parser will understand them on Function definitions. - Connect some special operators to their attributes so parsing understands them at call sites. - Don't silently ignore attributes during parsing. - Implement OptFunctionOnDevice so won't add device annotations for kUnknownDeviceType. - Allow the parser to be given an initial metadata map to support examples which need constants. - More DLOG -> VLOG conversions to reduce debug clutter. * [checkpoint] Keep existing ParseModule ffi to simplify rust bindings * [checkpoint] Address Christopher's comments. * [checkpoint] Andrew's comments from #9038 * [checkpoint] Jared's comments from #9038 * [checkpoint] Woops, forgot rename. --- include/tvm/ir/attrs.h | 25 +++++++ include/tvm/ir/function.h | 21 ++++++ include/tvm/parser/parser.h | 8 ++- include/tvm/relay/attrs/annotation.h | 41 ++++++++++-- include/tvm/relay/attrs/function.h | 66 ------------------- python/tvm/parser/__init__.py | 6 +- src/ir/diagnostic.cc | 6 +- src/parser/meta_ref.h | 3 +- src/parser/parser.cc | 52 ++++++++++----- src/parser/source_map.cc | 6 +- src/relay/op/annotation/annotation.cc | 50 +++++++------- src/relay/op/annotation/annotation.h | 23 +++++-- src/relay/op/memory/device_copy.cc | 1 + src/relay/op/memory/device_copy.h | 2 +- src/relay/op/memory/memory.cc | 2 + src/relay/op/vm/vm.cc | 3 + src/relay/transforms/type_infer.cc | 1 - .../relay/op/annotation/test_annotation.py | 11 ++-- tests/python/relay/test_ir_parser.py | 25 ++++++- 19 files changed, 211 insertions(+), 141 deletions(-) delete mode 100644 include/tvm/relay/attrs/function.h diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index fa1861051e2f..715c96eb6ea5 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -357,6 +357,31 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v return input; } +/*! + * \brief Copy the function or module, but overrides the attributes with the entries from \p attrs. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attrs Key/values attributes to add to \p input. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with updated attributes. + */ +template +inline TFunc WithAttrs(TFunc input, Map attrs) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = input.CopyOnWrite(); + if (node->attrs.defined()) { + for (const auto& pair : attrs) { + node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); + } + } else { + node->attrs = DictAttrs(std::move(attrs)); + } + return input; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 13b984d9cb35..5ee719f9964f 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -189,6 +189,27 @@ constexpr const char* kTarget = "target"; * Type: String */ constexpr const char* kGlobalSymbol = "global_symbol"; + +/*! + * \brief The device type which will hold each of the functions parameters. + * + * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but + * may be included as an annotation on user programs. + * + * Type: Array (but interpreted as Array) + */ +constexpr const char* kParamDeviceTypes = "param_device_types"; + +/*! + * \brief The device type which will hold the function result. + * + * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but + * may be included as an annotation on user programs. + * + * Type: Integer (but interpreted as DLDeviceType) + */ +constexpr const char* kResultDeviceType = "result_device_type"; + } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/parser/parser.h b/include/tvm/parser/parser.h index 7673eec2a337..8c2722050905 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/parser/parser.h @@ -23,6 +23,7 @@ * \file parser.h * \brief A parser for TVM IR. */ +#include #include #include @@ -32,8 +33,11 @@ namespace tvm { namespace parser { -IRModule ParseModule(std::string file_name, std::string file_content, - Optional init_module = Optional()); +using MetaTable = Map>; + +IRModule ParseModule(const std::string& file_name, const std::string& file_content, + const Optional& init_module = Optional(), + const MetaTable& init_meta_table = MetaTable()); } // namespace parser } // namespace tvm diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index bc55965ee852..85ac3f36ff60 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -32,17 +32,44 @@ namespace tvm { namespace relay { /*! - * \brief Attributes for the "on_device" operator. + * \brief Attributes for the "on_device" special operator. * - * The relay call + * The Relay call (aka 'annotation'): * \code - * on_device(expr, device_type=2) + * on_device(sub_expr, device_type=2) * \endcode - * denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2 - * (i.e. \p kDLCuda). Semantically the operator is the identity function. + * constrains \p sub_expr to execute and store its result on a device with \p DLDeviceType \p 2 + * (i.e. a \p kDLCuda device). However the annotation itself may appear in an expression to be + * executed and stored on a different device. If so the compiler will automatically insert a + * "device_copy" call to mediate the transition between devices. * - * See also FunctionOnDeviceAttrs in include/relay/attrs/function.h for the function-level - * companion. + * E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then: + * \code + * multiply(on_device(add(%x, %y), device_type=2), %z) + * \endcode + * indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU. + * The compiler will rewrite this to: + * \code + * multiply(device_copy(add(%x, %y), src_dev_type=2, dst_dev_type=1), %z) + * \endcode + * + * The Relay call + * \code + * on_device(sub_expr, device_type=2, is_fixed=True) + * \endcode + * is similar to the above, however the annotation itself must appear in an expression on the + * same device. The compiler will check the devices are consistent, and will not insert any + * "device_copy" call. This form of annotation shouldn't be necessary in user programs. However + * it is needed by the \p PlanDevices pass to fully specify the results of device planning so that + * the pass is idempotent. + * + * E.g.: The following program is equivalent to the above: + * \code + * let %a = on_device(add(%x, %y), device_type=2, is_fixed=True) + * multiply(device_copy(%a, src_dev_type=2, dst_dev_type=1), %z) + * \endcode + * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored + * on the GPU. */ struct OnDeviceAttrs : public tvm::AttrsNode { // TODO(mbs): Replace device types with TargetDevice. diff --git a/include/tvm/relay/attrs/function.h b/include/tvm/relay/attrs/function.h deleted file mode 100644 index f4f94131da1f..000000000000 --- a/include/tvm/relay/attrs/function.h +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relay/attrs/function.h - * \brief Attributes for Relay Functions which don't make sense on PrimFuncs. - */ -#ifndef TVM_RELAY_ATTRS_FUNCTION_H_ -#define TVM_RELAY_ATTRS_FUNCTION_H_ - -namespace tvm { -namespace relay { -/*! - * \brief Attributes for Relay function definitions which capture the devices for the - * function parameters and result. - * - * See also OnDeviceAttrs in include/tvm/relay/attrs/annotation.h for the companion "on_device" - * call attributes. - */ -struct FunctionOnDeviceAttrs : public tvm::AttrsNode { - /*! \brief Device type on which each of the function's arguments already resides. */ - Array param_device_types; - // TODO(mbs): Replace device types with TargetDevice. - /*! \brief Device type on which function body should be evaluated. */ - int result_device_type = kInvalidDeviceType; - - TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") { - TVM_ATTR_FIELD(param_device_types) - .describe("The type of the virtual device which holds each function parameters."); - TVM_ATTR_FIELD(result_device_type) - .describe("The type of the virtual device which will hold the function's result.") - .set_default(0); - } -}; - -namespace attr { - -/*! - * \brief Device annotations for function parameters and results. - * - * Type: FunctionOnDeviceAttrs - */ -constexpr static const char* kFunctionAttrsKey = "on_device"; - -} // namespace attr - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_ATTRS_FUNCTION_H_ diff --git a/python/tvm/parser/__init__.py b/python/tvm/parser/__init__.py index 60fcddb17f08..d75ad16ebab2 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/parser/__init__.py @@ -26,8 +26,10 @@ def add(self, name, content): return _ffi.get_global_func("SourceMapAdd")(self, name, content) -def parse(source, source_name="from_string"): - return _ffi_api.ParseModule(source_name, source) +def parse(source, source_name="from_string", init_module=None, init_meta_table=None): + if init_meta_table is None: + init_meta_table = {} + return _ffi_api.ParseModuleInContext(source_name, source, init_module, init_meta_table) def parse_expr(source): diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 876113b85f6e..b9677d198eba 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -242,10 +242,10 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s } auto source = (*it).second; - DLOG(INFO) << "Source: " << std::endl << source->source; + VLOG(1) << "Source: " << std::endl << source->source; - DLOG(INFO) << "ReportAt " - << "span = " << span << " msg = " << diagnostic->message; + VLOG(1) << "ReportAt " + << "span = " << span << " msg = " << diagnostic->message; auto line_text = source.GetLine(span->line); diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index 481f334cb0fe..483b7f726e07 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -26,6 +26,7 @@ #define TVM_PARSER_META_REF_H_ #include +#include #include #include @@ -36,8 +37,6 @@ namespace parser { using namespace relay; -using MetaTable = Map>; - /*! * \brief Options for allocating storage. */ diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 93dc687d72f5..5eec716cc20c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1092,8 +1092,6 @@ class Parser { Array generics; if (Peek()->token_type == TokenType::kLSquare) { - // If we have generics we need to add a type scope. - PushTypeScope(); generics = ParseSequence( TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { auto type_var_name = Match(TokenType::kIdentifier).ToString(); @@ -1444,6 +1442,10 @@ class Parser { ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } + } else { + this->diag_ctx.EmitFatal(Diagnostic::Error(op->span) + << "unable to determine the 'attrs_type_key' with which " + "to represent the call attributes for this operator"); } } return true; @@ -1867,7 +1869,7 @@ class Parser { }; Parser InitParser(const std::string& file_name, const std::string& file_content, - Optional init_module) { + const Optional& init_module, const MetaTable& init_meta_table) { VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size(); SourceName src_name = SourceName::Get(file_name); Source source(src_name, file_content); @@ -1886,19 +1888,33 @@ Parser InitParser(const std::string& file_name, const std::string& file_content, auto tokens_and_table = Tokenize(diag_ctx, source); auto tokens = tokens_and_table.first; - auto meta_data_table = tokens_and_table.second; + MetaTable meta_data_table = tokens_and_table.second.ToMetadata(); + + // Merge any entries in init_meta_table into anything captured in the #[metadata] section + // of the file_content. Metadata references within file_content must use indexes which account + // for this ordering. + for (const auto& pair : init_meta_table) { + Array items; + if (meta_data_table.count(pair.first)) { + items = meta_data_table[pair.first]; + } + for (const auto& obj : pair.second) { + items.push_back(obj); + } + meta_data_table.Set(pair.first, items); + } - return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), meta_data_table.ToMetadata()); + return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table)); } -IRModule ParseModule(std::string file_name, std::string file_content, - Optional init_module) { +IRModule ParseModule(const std::string& file_name, const std::string& file_content, + const Optional& init_module, const MetaTable& init_meta_table) { VLOG(0) << "ParseModule"; - auto parser = InitParser(file_name, file_content, init_module); + auto parser = InitParser(file_name, file_content, init_module, init_meta_table); auto mod = parser.ParseModule(); ICHECK(mod.defined()) << "The parser must return a non-null module."; - // NB(@jroesch): it is very important that we render any errors before we procede - // if there were any errors which allow the parser to procede we must render them + // NB(@jroesch): it is very important that we render any errors before we proceed + // if there were any errors which allow the parser to proceed we must render them // here. parser.diag_ctx.Render(); auto infer_type = tvm::relay::transform::InferType(); @@ -1906,22 +1922,28 @@ IRModule ParseModule(std::string file_name, std::string file_content, return infer_type(mod); } -Expr ParseExpr(std::string file_name, std::string file_content) { +Expr ParseExpr(const std::string& file_name, const std::string& file_content) { VLOG(0) << "ParseExpr"; - auto parser = InitParser(file_name, file_content, Optional()); + auto parser = InitParser(file_name, file_content, Optional(), MetaTable()); parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); parser.Match(TokenType::kEndOfFile); - // NB(@jroesch): it is very important that we render any errors before we procede - // if there were any errors which allow the parser to procede we must render them + // NB(@jroesch): it is very important that we render any errors before we proceed + // if there were any errors which allow the parser to proceed we must render them // here. parser.diag_ctx.Render(); return expr; } +TVM_REGISTER_GLOBAL("parser.ParseModuleInContext") + .set_body_typed([](const std::string& file_name, const std::string& file_content, + const Optional& init_module, const MetaTable& init_meta_table) { + return ParseModule(file_name, file_content, init_module, init_meta_table); + }); + TVM_REGISTER_GLOBAL("parser.ParseModule") - .set_body_typed([](tvm::String file_name, tvm::String file_content) { + .set_body_typed([](const std::string& file_name, const std::string& file_content) { return ParseModule(file_name, file_content); }); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 4e79d0e74c59..3c1329670c40 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -60,7 +60,7 @@ Source::Source(SourceName src_name, std::string source) { } tvm::String Source::GetLine(int line) { - DLOG(INFO) << "Source::GetLine: line=" << line; + VLOG(1) << "Source::GetLine: line=" << line; ICHECK(line - 1 < static_cast((*this)->line_map.size())) << "requested line: " << line << "at index: " << (line - 1) << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; @@ -69,10 +69,10 @@ tvm::String Source::GetLine(int line) { auto range = (*this)->line_map.at(line - 1); int line_start = range.first; int line_length = range.second; - DLOG(INFO) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; + VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; // TODO(@jroesch): expose substring on tvm::String. auto line_text = std::string((*this)->source).substr(line_start, line_length); - DLOG(INFO) << "Source::GetLine: line_text=" << line_text; + VLOG(1) << "Source::GetLine: line_text=" << line_text; return line_text; } diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 4eda15937f3a..284f8b88ee0d 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -26,7 +26,6 @@ #include "./annotation.h" #include -#include #include #include #include @@ -54,7 +53,7 @@ Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span); } -Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { +Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { if (device_type == kInvalidDeviceType) { // Undefined signals no annotation is required. return expr; @@ -92,6 +91,7 @@ RELAY_REGISTER_OP("on_device") .add_argument("data", "Tensor", "The input data.") .set_support_level(10) .add_type_rel("Identity", IdentityRel) + .set_attrs_type_key("relay.attrs.OnDeviceAttrs") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) @@ -128,14 +128,10 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { return {}; } -TVM_REGISTER_NODE_TYPE(FunctionOnDeviceAttrs); - Function FunctionOnDevice(Function function, Array param_device_types, - DLDeviceType result_device_type) { - auto attrs = make_object(); - attrs->param_device_types = std::move(param_device_types); - attrs->result_device_type = result_device_type; - return WithAttr(std::move(function), attr::kFunctionAttrsKey, Attrs(std::move(attrs))); + Integer result_device_type) { + return WithAttrs(std::move(function), {{tvm::attr::kParamDeviceTypes, param_device_types}, + {tvm::attr::kResultDeviceType, result_device_type}}); } Function FunctionOnDevice(Function function, const std::vector& param_device_types, @@ -143,9 +139,21 @@ Function FunctionOnDevice(Function function, const std::vector& pa Array arr; arr.reserve(param_device_types.size()); for (const auto device_type : param_device_types) { - arr.push_back(static_cast(device_type)); + arr.push_back(static_cast(device_type)); + } + return FunctionOnDevice(std::move(function), std::move(arr), + static_cast(result_device_type)); +} + +Function MaybeFunctionOnDevice(Function function, + const std::vector& param_device_types, + DLDeviceType result_device_type) { + if (std::all_of(param_device_types.begin(), param_device_types.end(), + [](DLDeviceType type) { return type == kInvalidDeviceType; }) && + result_device_type == kInvalidDeviceType) { + return function; } - return FunctionOnDevice(function, arr, result_device_type); + return FunctionOnDevice(function, param_device_types, result_device_type); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") @@ -156,32 +164,26 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") }); DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node) { - auto opt_attrs = function_node->GetAttr(attr::kFunctionAttrsKey); - if (!opt_attrs) { + auto opt_integer = function_node->GetAttr(tvm::attr::kResultDeviceType); + if (!opt_integer) { // No annotation. return kInvalidDeviceType; } - const auto* opt_function_on_device_attrs = opt_attrs.value().as(); - ICHECK(opt_function_on_device_attrs != nullptr) - << "function '" << attr::kFunctionAttrsKey << "' annotation must be a FunctionOnDeviceAttrs"; - return static_cast(opt_function_on_device_attrs->result_device_type); + return static_cast(opt_integer.value()->value); } DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i) { ICHECK_LT(i, function_node->params.size()) << "param index " << i << " out of range for function of arity " << function_node->params.size(); - auto opt_attrs = function_node->GetAttr(attr::kFunctionAttrsKey); - if (!opt_attrs) { + auto opt_array = function_node->GetAttr>(tvm::attr::kParamDeviceTypes); + if (!opt_array) { // No annotation. return kInvalidDeviceType; } - const auto* opt_function_on_device_attrs = opt_attrs.value().as(); - ICHECK(opt_function_on_device_attrs != nullptr) - << "function '" << attr::kFunctionAttrsKey << "' annotation must be a FunctionOnDeviceAttrs"; - ICHECK_EQ(opt_function_on_device_attrs->param_device_types.size(), function_node->params.size()) + ICHECK_EQ(opt_array.value().size(), function_node->params.size()) << "annotation parameters do not match function arity"; - return static_cast(opt_function_on_device_attrs->param_device_types[i]->value); + return static_cast(opt_array.value()[i]->value); } Expr StopFusion(Expr data) { diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index e3a4aea4708c..643a82116b5b 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -39,6 +39,8 @@ const Op& OnDeviceOp(); /*! * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. + * + * See \p OnDeviceAttrs for an overview. */ Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); @@ -52,7 +54,7 @@ Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); * - \p expr is a constructor. There should probably be device polymorphic but are in an * in-between state at the moment. */ -Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); +Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); /*! \brief Result of \p GetOnDeviceProps. */ struct OnDeviceProps { @@ -83,24 +85,31 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr); inline bool IsOnDeviceCall(const Expr& expr) { return GetOnDeviceProps(expr).body.defined(); } /*! - * \brief Returns \p function annotated with "on_device" attributes capturing parameter and result - * devices types. However returns \p function directly if all device types are \p - * kInvalidDeviceType. + * \brief Returns \p function annotated with "param_device_types" and "result_device_type" + * attributes capturing parameter and result devices types respectively. */ Function FunctionOnDevice(Function function, Array param_device_types, - DLDeviceType body_device_type); + Integer body_device_type); Function FunctionOnDevice(Function function, const std::vector& param_device_types, DLDeviceType body_device_type); +/*! + * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and + * result device types are \p kInvalidDeviceType. + */ +Function MaybeFunctionOnDevice(Function function, + const std::vector& param_device_types, + DLDeviceType result_device_type); + /*! * \brief Returns the device type for the resut of \p function_node, or \p kInvalidDeviceType - * if function does not have "on_device" annotation. + * if function does not have "result_device_type" annotation. */ DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node); /*! * \brief Returns the device type for the \p i'th parameter of \p function_node, or - * \p kInvalidDeviceType if function does not have "on_device" annotation. + * \p kInvalidDeviceType if function does not have "param_device_types" annotation. */ DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i); diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index b94caac2c3d9..dce89aa91b65 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -76,6 +76,7 @@ on different devices. .add_argument("data", "Tensor", "The input data.") .set_support_level(10) .add_type_rel("Identity", IdentityRel) + .set_attrs_type_key("relay.attrs.DeviceCopyAttrs") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h index d590d8510f17..d21fdb6abe19 100644 --- a/src/relay/op/memory/device_copy.h +++ b/src/relay/op/memory/device_copy.h @@ -45,7 +45,7 @@ Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. * However, return \p expr directly if \p src_dev_type equals \p dst_dev_type. */ -Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); +Expr MaybeDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); /*! \brief Result of \p GetDeviceCopyProps. */ struct DeviceCopyProps { diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 68a83ebba1fe..5339d48e3a2f 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -86,6 +86,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") .add_argument("size", "Tensor", "The size of the storage to allocate.") .add_argument("alignment", "Tensor", "The alignment of the storage.") .add_type_rel("AllocStorage", AllocStorageRel) + .set_attrs_type_key("relay.attrs.AllocStorageAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) @@ -200,6 +201,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor") .add_argument("offset", "Tensor", "The offset into the backing storage.") .add_argument("shape", "Tensor", "The shape of the tensor to allocate.") .add_type_rel("AllocTensor", AllocTensorRel) + .set_attrs_type_key("relay.attrs.AllocTensorAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc index a74a259a114f..be31b5482937 100644 --- a/src/relay/op/vm/vm.cc +++ b/src/relay/op/vm/vm.cc @@ -50,6 +50,7 @@ RELAY_REGISTER_OP("vm.shape_of") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The input tensor") .add_type_rel("ShapeOf", ShapeOfRel) + .set_attrs_type_key("relay.attrs.ShapeOfAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) @@ -131,6 +132,7 @@ RELAY_REGISTER_OP("vm.shape_func") .add_argument("func", "Function", "The operation to call") .add_argument("ins", "Tuple", "The input tensors.") .add_argument("outs", "Tuple", "The output tensors.") + .set_attrs_type_key("relay.attrs.ShapeFuncAttrs") .add_type_rel("ShapeFuncRel", ShapeFuncRel) .set_support_level(10) .set_attr("TOpPattern", kOpaque) @@ -214,6 +216,7 @@ RELAY_REGISTER_OP("vm.reshape_tensor") .add_argument("data", "Tensor", "The input tensor") .add_argument("shape", "Tensor", "The output shape tensor") .add_type_rel("ReshapeTensor", ReshapeTensorRel) + .set_attrs_type_key("relay.attrs.ReshapeTensorAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 6c2371716b16..ebdf1fed2fab 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -824,7 +824,6 @@ Pass InferType() { auto pass_info = PassInfo(0, "InferType", {}); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { - DLOG(INFO) << "tvm::relay::transform::InferType"; // Execute the pass function and return a new module. IRModule updated_mod = mod->ShallowCopy(); diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 51daa9aaa06a..58e559eb9680 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -54,13 +54,10 @@ def test_function_on_device(): f = relay.Function([x, y], relay.add(x, y)) func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") assert isinstance(func, relay.Function) - assert len(func.attrs["on_device"].param_device_types) == 2 - assert func.attrs["on_device"].param_device_types[0] == 1 - # ie kDLCPU - assert func.attrs["on_device"].param_device_types[1] == 2 - # ie kDLCUDA - assert func.attrs["on_device"].result_device_type == 2 - # ie KDLCUDA + assert len(func.attrs["param_device_types"]) == 2 + assert func.attrs["param_device_types"][0] == 1 # ie kDLCPU + assert func.attrs["param_device_types"][1] == 2 # ie kDLCUDA + assert func.attrs["result_device_type"] == 2 # ie KDLCUDA if __name__ == "__main__": diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 099e127aeba9..fdbd3924ffb7 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -23,7 +23,6 @@ from numpy import isclose from typing import Union - SEMVER = '#[version = "0.0.5"]\n' BINARY_OPS = { @@ -967,6 +966,30 @@ def test_func_attrs(): assert_parses_as(func.astext(), func) +def test_init_module_and_metatable(): + init_metatable = {"relay.Constant": [relay.const(np.random.rand(2, 3), dtype="float32")]} + init_module = tvm.parser.fromtext( + SEMVER + + """ + def @f(%y : Tensor[(2, 3), float32]) -> Tensor[(2, 3), float32] { + negative(%y) + } + """, + ) + mod = tvm.parser.parse( + SEMVER + + """ + def @main(%x: Tensor[(2, 3), float32]) { + add(@f(%x), meta[relay.Constant][0]) + } + """, + "from_string", + init_module, + init_metatable, + ) + roundtrip(mod) + + if __name__ == "__main__": import sys