From 08ea4ee6b337a69295eb50413c2bf62afde86b58 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 30 Apr 2020 16:11:38 -0700 Subject: [PATCH] [IR] Initial stab at std::string->String upgrade (#5438) --- include/tvm/ir/span.h | 4 ++-- include/tvm/ir/type.h | 4 ++-- python/tvm/ir/json_compact.py | 27 ++++++++++++++++++++++++--- src/ir/span.cc | 18 +++++++++++------- src/ir/type.cc | 4 ++-- 5 files changed, 41 insertions(+), 16 deletions(-) diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 7194e903549c..411b733d2af9 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -40,7 +40,7 @@ class SourceName; class SourceNameNode : public Object { public: /*! \brief The source name. */ - std::string name; + String name; // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } @@ -64,7 +64,7 @@ class SourceName : public ObjectRef { * \param name Name of the operator. * \return SourceName valid throughout program lifetime. */ - TVM_DLL static SourceName Get(const std::string& name); + TVM_DLL static SourceName Get(const String& name); TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 0ef03c42c03e..a20dbddb5c77 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -227,7 +227,7 @@ class TypeVarNode : public TypeNode { * this only acts as a hint to the user, * and is not used for equality. */ - std::string name_hint; + String name_hint; /*! \brief The kind of type parameter */ TypeKind kind; @@ -263,7 +263,7 @@ class TypeVar : public Type { * \param name_hint The name of the type var. * \param kind The kind of the type var. */ - TVM_DLL TypeVar(std::string name_hint, TypeKind kind); + TVM_DLL TypeVar(String name_hint, TypeKind kind); TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); }; diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 9a881cfa6d5b..fcea9d821222 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -16,6 +16,9 @@ # under the License. """Tool to upgrade json from historical versions.""" import json +import tvm.ir +import tvm.runtime + def create_updater(node_map, from_ver, to_ver): """Create an updater to update json loaded data. @@ -41,8 +44,12 @@ def _updater(data): nodes = data["nodes"] for idx, item in enumerate(nodes): f = node_map.get(item["type_key"], None) - if f: - nodes[idx] = f(item, nodes) + if isinstance(f, list): + for fpass in f: + item = fpass(item, nodes) + elif f: + item = f(item, nodes) + nodes[idx] = item data["attrs"]["tvm_version"] = to_ver return data return _updater @@ -84,12 +91,26 @@ def _update_global_key(item, _): del item["global_key"] return item + def _update_from_std_str(key): + def _convert(item, nodes): + str_val = item["attrs"][key] + jdata = json.loads(tvm.ir.save_json(tvm.runtime.String(str_val))) + root_idx = jdata["root"] + val = jdata["nodes"][root_idx] + sidx = len(nodes) + nodes.append(val) + item["attrs"][key] = '%d' % sidx + return item + + return _convert + + node_map = { # Base IR "SourceName": _update_global_key, "EnvFunc": _update_global_key, "relay.Op": _update_global_key, - "relay.TypeVar": _ftype_var, + "relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")], "relay.GlobalTypeVar": _ftype_var, "relay.Type": _rename("Type"), "relay.TupleType": _rename("TupleType"), diff --git a/src/ir/span.cc b/src/ir/span.cc index f84353de2a8b..5a06a1014645 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -25,10 +25,10 @@ namespace tvm { -ObjectPtr GetSourceNameNode(const std::string& name) { +ObjectPtr GetSourceNameNode(const String& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr - static std::unordered_map > source_map; + static std::unordered_map > source_map; auto sn = source_map.find(name); if (sn == source_map.end()) { @@ -41,7 +41,11 @@ ObjectPtr GetSourceNameNode(const std::string& name) { } } -SourceName SourceName::Get(const std::string& name) { +ObjectPtr GetSourceNameNodeByStr(const std::string& name) { + return GetSourceNameNode(name); +} + +SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } @@ -55,10 +59,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_REGISTER_NODE_TYPE(SourceNameNode) -.set_creator(GetSourceNameNode) -.set_repr_bytes([](const Object* n) { - return static_cast(n)->name; - }); +.set_creator(GetSourceNameNodeByStr) +.set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; +}); Span SpanNode::make(SourceName source, int lineno, int col_offset) { auto n = make_object(); diff --git a/src/ir/type.cc b/src/ir/type.cc index 5b038218c127..5d4689377fc8 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); -TypeVar::TypeVar(std::string name, TypeKind kind) { +TypeVar::TypeVar(String name, TypeKind kind) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); @@ -76,7 +76,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_GLOBAL("ir.TypeVar") -.set_body_typed([](std::string name, int kind) { +.set_body_typed([](String name, int kind) { return TypeVar(name, static_cast(kind)); });