From 738bb295d5143c8d4381f3ad1b3c0ac90ed1eea4 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Mon, 27 Sep 2021 15:20:47 -0700 Subject: [PATCH] [checkpoint] Jared's comments from #9038 --- include/tvm/ir/attrs.h | 25 +++++++++++++++++++++++++ src/parser/parser.cc | 4 +--- src/relay/op/annotation/annotation.cc | 4 ++-- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index fa1861051e2f..715c96eb6ea5 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -357,6 +357,31 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v return input; } +/*! + * \brief Copy the function or module, but overrides the attributes with the entries from \p attrs. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attrs Key/values attributes to add to \p input. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with updated attributes. + */ +template +inline TFunc WithAttrs(TFunc input, Map attrs) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = input.CopyOnWrite(); + if (node->attrs.defined()) { + for (const auto& pair : attrs) { + node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); + } + } else { + node->attrs = DictAttrs(std::move(attrs)); + } + return input; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/src/parser/parser.cc b/src/parser/parser.cc index b5fd52e7f1da..5eec716cc20c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1088,12 +1088,10 @@ class Parser { VLOG(0) << "Parser::ParseFunctionDef"; return WithSpan([&]() { PushScope(); - PushTypeScope(); // TODO(mbs): BUG? + PushTypeScope(); Array generics; if (Peek()->token_type == TokenType::kLSquare) { - // If we have generics we need to add a type scope. - PushTypeScope(); generics = ParseSequence( TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { auto type_var_name = Match(TokenType::kIdentifier).ToString(); diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 4bcf7024b901..474e8097f3cc 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -130,8 +130,8 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_device_types, Integer result_device_type) { - return WithAttr(WithAttr(std::move(function), tvm::attr::kParamDeviceTypes, param_device_types), - tvm::attr::kResultDeviceType, result_device_type); + return WithAttrs(std::move(function), {{tvm::attr::kParamDeviceTypes, param_device_types}, + {tvm::attr::kResultDeviceType, result_device_type}}); } Function FunctionOnDevice(Function function, const std::vector& param_device_types,