Skip to content

Commit

Permalink
[checkpoint] Jared's comments from apache#9038
Browse files Browse the repository at this point in the history
  • Loading branch information
mbs-octoml committed Sep 27, 2021
1 parent 5553be2 commit 738bb29
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
25 changes: 25 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,31 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v
return input;
}

/*!
* \brief Copy the function or module, but overrides the attributes with the entries from \p attrs.
*
* \param input The thing to annotate (BaseFunc or IRModule)
* \param attrs Key/values attributes to add to \p input.
*
* \tparam TFunc The corresponding function or module type.
*
* \returns The new function or module with updated attributes.
*/
template <typename TFunc>
inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
for (const auto& pair : attrs) {
node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
}
} else {
node->attrs = DictAttrs(std::move(attrs));
}
return input;
}

// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
Expand Down
4 changes: 1 addition & 3 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1088,12 +1088,10 @@ class Parser {
VLOG(0) << "Parser::ParseFunctionDef";
return WithSpan<Function>([&]() {
PushScope();
PushTypeScope(); // TODO(mbs): BUG?
PushTypeScope();

Array<TypeVar> generics;
if (Peek()->token_type == TokenType::kLSquare) {
// If we have generics we need to add a type scope.
PushTypeScope();
generics = ParseSequence<TypeVar>(
TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
auto type_var_name = Match(TokenType::kIdentifier).ToString();
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {

Function FunctionOnDevice(Function function, Array<Integer> param_device_types,
Integer result_device_type) {
return WithAttr(WithAttr(std::move(function), tvm::attr::kParamDeviceTypes, param_device_types),
tvm::attr::kResultDeviceType, result_device_type);
return WithAttrs(std::move(function), {{tvm::attr::kParamDeviceTypes, param_device_types},
{tvm::attr::kResultDeviceType, result_device_type}});
}

Function FunctionOnDevice(Function function, const std::vector<DLDeviceType>& param_device_types,
Expand Down

0 comments on commit 738bb29

Please sign in to comment.