Skip to content

Commit

Permalink
[IR] Initial stab at std::string->String upgrade (apache#5438)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and trevor-m committed Jun 18, 2020
1 parent c39540b commit 08ea4ee
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 16 deletions.
4 changes: 2 additions & 2 deletions include/tvm/ir/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class SourceName;
class SourceNameNode : public Object {
public:
/*! \brief The source name. */
std::string name;
String name;
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }

Expand All @@ -64,7 +64,7 @@ class SourceName : public ObjectRef {
* \param name Name of the operator.
* \return SourceName valid throughout program lifetime.
*/
TVM_DLL static SourceName Get(const std::string& name);
TVM_DLL static SourceName Get(const String& name);

TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
};
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class TypeVarNode : public TypeNode {
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
String name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;

Expand Down Expand Up @@ -263,7 +263,7 @@ class TypeVar : public Type {
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
TVM_DLL TypeVar(std::string name_hint, TypeKind kind);
TVM_DLL TypeVar(String name_hint, TypeKind kind);

TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
Expand Down
27 changes: 24 additions & 3 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# under the License.
"""Tool to upgrade json from historical versions."""
import json
import tvm.ir
import tvm.runtime


def create_updater(node_map, from_ver, to_ver):
"""Create an updater to update json loaded data.
Expand All @@ -41,8 +44,12 @@ def _updater(data):
nodes = data["nodes"]
for idx, item in enumerate(nodes):
f = node_map.get(item["type_key"], None)
if f:
nodes[idx] = f(item, nodes)
if isinstance(f, list):
for fpass in f:
item = fpass(item, nodes)
elif f:
item = f(item, nodes)
nodes[idx] = item
data["attrs"]["tvm_version"] = to_ver
return data
return _updater
Expand Down Expand Up @@ -84,12 +91,26 @@ def _update_global_key(item, _):
del item["global_key"]
return item

def _update_from_std_str(key):
def _convert(item, nodes):
str_val = item["attrs"][key]
jdata = json.loads(tvm.ir.save_json(tvm.runtime.String(str_val)))
root_idx = jdata["root"]
val = jdata["nodes"][root_idx]
sidx = len(nodes)
nodes.append(val)
item["attrs"][key] = '%d' % sidx
return item

return _convert


node_map = {
# Base IR
"SourceName": _update_global_key,
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": _ftype_var,
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
Expand Down
18 changes: 11 additions & 7 deletions src/ir/span.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@

namespace tvm {

ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
ObjectPtr<Object> GetSourceNameNode(const String& name) {
// always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr
static std::unordered_map<std::string, ObjectPtr<SourceNameNode> > source_map;
static std::unordered_map<String, ObjectPtr<SourceNameNode> > source_map;

auto sn = source_map.find(name);
if (sn == source_map.end()) {
Expand All @@ -41,7 +41,11 @@ ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
}
}

SourceName SourceName::Get(const std::string& name) {
ObjectPtr<Object> GetSourceNameNodeByStr(const std::string& name) {
return GetSourceNameNode(name);
}

SourceName SourceName::Get(const String& name) {
return SourceName(GetSourceNameNode(name));
}

Expand All @@ -55,10 +59,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_repr_bytes([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
.set_creator(GetSourceNameNodeByStr)
.set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const SourceNameNode*>(n)->name;
});

Span SpanNode::make(SourceName source, int lineno, int col_offset) {
auto n = make_object<SpanNode>();
Expand Down
4 changes: 2 additions & 2 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});


TypeVar::TypeVar(std::string name, TypeKind kind) {
TypeVar::TypeVar(String name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
Expand All @@ -76,7 +76,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);

TVM_REGISTER_GLOBAL("ir.TypeVar")
.set_body_typed([](std::string name, int kind) {
.set_body_typed([](String name, int kind) {
return TypeVar(name, static_cast<TypeKind>(kind));
});

Expand Down

0 comments on commit 08ea4ee

Please sign in to comment.