From ac386a7249c48b7b0e6ab4af6b1b4408e86f98d1 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 9 Apr 2020 13:00:16 -0700 Subject: [PATCH] [NODE] General serialization of leaf objects into bytes. This PR refactors the serialization mechanism to support general serialization of leaf objects into bytes. The new feature supersedes the original GetGlobalKey feature for singletons. Added serialization support for runtime::String. --- include/tvm/node/reflection.h | 52 ++++++++------ python/tvm/ir/json_compact.py | 8 +++ src/ir/env_func.cc | 2 +- src/ir/op.cc | 2 +- src/ir/span.cc | 2 +- src/node/container.cc | 16 ++++- src/node/reflection.cc | 4 +- src/node/serialization.cc | 71 +++++++++++++++---- tests/python/relay/test_json_compact.py | 18 +++++ tests/python/unittest/test_node_reflection.py | 13 ++++ 10 files changed, 144 insertions(+), 44 deletions(-) diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 18dfa129fa39..9ed87df46618 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -98,17 +98,17 @@ class ReflectionVTable { typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce); /*! * \brief creator function. - * \param global_key Key that identifies a global single object. - * If this is not empty then FGlobalKey must be defined for the object. + * \param repr_bytes Repr bytes to create the object. + * If this is not empty then FReprBytes must be defined for the object. * \return The created function. */ - typedef ObjectPtr (*FCreate)(const std::string& global_key); + typedef ObjectPtr (*FCreate)(const std::string& repr_bytes); /*! - * \brief Global key function, only needed by global objects. + * \brief Function to get a byte representation that can be used to recover the object. * \param node The node pointer. - * \return node The global key to the node. + * \return bytes The bytes that can be used to recover the object. 
*/ - typedef std::string (*FGlobalKey)(const Object* self); + typedef std::string (*FReprBytes)(const Object* self); /*! * \brief Dispatch the VisitAttrs function. * \param self The pointer to the object. @@ -116,11 +116,13 @@ class ReflectionVTable { */ inline void VisitAttrs(Object* self, AttrVisitor* visitor) const; /*! - * \brief Get global key of the object, if any. + * \brief Get repr bytes if any. * \param self The pointer to the object. - * \return the global key if object has one, otherwise return empty string. + * \param repr_bytes The output repr bytes, can be null, in which case the function + * simply queries if the ReprBytes function exists for the type. + * \return Whether repr bytes exists */ - inline std::string GetGlobalKey(Object* self) const; + inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const; /*! * \brief Dispatch the SEqualReduce function. * \param self The pointer to the object. @@ -141,10 +143,10 @@ class ReflectionVTable { * by type_key and global key. * * \param type_key The type key of the object. - * \param global_key A global key that can be used to uniquely identify the object if any. + * \param repr_bytes Bytes representation of the object if any. */ TVM_DLL ObjectPtr CreateInitObject(const std::string& type_key, - const std::string& global_key = "") const; + const std::string& repr_bytes = "") const; /*! * \brief Get an field object by the attr name. * \param self The pointer to the object. @@ -176,8 +178,8 @@ class ReflectionVTable { std::vector fshash_reduce_; /*! \brief Creation function. */ std::vector fcreate_; - /*! \brief Global key function. */ - std::vector fglobal_key_; + /*! \brief ReprBytes function. */ + std::vector frepr_bytes_; }; /*! \brief Registry of a reflection table. */ @@ -196,13 +198,13 @@ class ReflectionVTable::Registry { return *this; } /*! - * \brief Set global_key function. - * \param f The creator function. + * \brief Set bytes repr function. + * \param f The ReprBytes function. 
* \return rference to self. */ - Registry& set_global_key(FGlobalKey f) { // NOLINT(*) - CHECK_LT(type_index_, parent_->fglobal_key_.size()); - parent_->fglobal_key_[type_index_] = f; + Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*) + CHECK_LT(type_index_, parent_->frepr_bytes_.size()); + parent_->frepr_bytes_[type_index_] = f; return *this; } @@ -365,7 +367,7 @@ ReflectionVTable::Register() { if (tindex >= fvisit_attrs_.size()) { fvisit_attrs_.resize(tindex + 1, nullptr); fcreate_.resize(tindex + 1, nullptr); - fglobal_key_.resize(tindex + 1, nullptr); + frepr_bytes_.resize(tindex + 1, nullptr); fsequal_reduce_.resize(tindex + 1, nullptr); fshash_reduce_.resize(tindex + 1, nullptr); } @@ -392,12 +394,16 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const { fvisit_attrs_[tindex](self, visitor); } -inline std::string ReflectionVTable::GetGlobalKey(Object* self) const { +inline bool ReflectionVTable::GetReprBytes(const Object* self, + std::string* repr_bytes) const { uint32_t tindex = self->type_index(); - if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) { - return fglobal_key_[tindex](self); + if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) { + if (repr_bytes != nullptr) { + *repr_bytes = frepr_bytes_[tindex](self); + } + return true; } else { - return std::string(); + return false; } } diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index aa43df5a6697..e091cd12a208 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -79,8 +79,16 @@ def _convert(item, _): return item return _convert + def _update_global_key(item, _): + item["repr_str"] = item["global_key"] + del item["global_key"] + return item + node_map = { # Base IR + "SourceName": _update_global_key, + "EnvFunc": _update_global_key, + "relay.Op": _update_global_key, "relay.TypeVar": _ftype_var, "relay.GlobalTypeVar": _ftype_var, "relay.Type": _rename("Type"), diff --git a/src/ir/env_func.cc 
b/src/ir/env_func.cc index 3e85c5f47b52..4d3ed30bc032 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc") TVM_REGISTER_NODE_TYPE(EnvFuncNode) .set_creator(CreateEnvNode) -.set_global_key([](const Object* n) -> std::string { +.set_repr_bytes([](const Object* n) -> std::string { return static_cast(n)->name; }); diff --git a/src/ir/op.cc b/src/ir/op.cc index 54374eb8a526..6a50240ee7a1 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -223,7 +223,7 @@ ObjectPtr CreateOp(const std::string& name) { TVM_REGISTER_NODE_TYPE(OpNode) .set_creator(CreateOp) -.set_global_key([](const Object* n) { +.set_repr_bytes([](const Object* n) { return static_cast(n)->name; }); diff --git a/src/ir/span.cc b/src/ir/span.cc index d03903c2d3a5..f84353de2a8b 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(SourceNameNode) .set_creator(GetSourceNameNode) -.set_global_key([](const Object* n) { +.set_repr_bytes([](const Object* n) { return static_cast(n)->name; }); diff --git a/src/node/container.cc b/src/node/container.cc index 8fff151ce605..e7e497946b6f 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -48,7 +48,21 @@ struct StringObjTrait { } }; -TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait); +struct RefToObjectPtr : public ObjectRef { + static ObjectPtr Get(const ObjectRef& ref) { + return GetDataPtr(ref); + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) +.set_creator([](const std::string& bytes) { + return RefToObjectPtr::Get(runtime::String(bytes)); +}) +.set_repr_bytes([](const Object* n) -> std::string { + return GetRef( + static_cast(n)).operator std::string(); +}); + struct ADTObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 824874f24ab0..08a914ff38f9 100644 --- 
a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -178,13 +178,13 @@ ReflectionVTable* ReflectionVTable::Global() { ObjectPtr ReflectionVTable::CreateInitObject(const std::string& type_key, - const std::string& global_key) const { + const std::string& repr_bytes) const { uint32_t tindex = Object::TypeKey2Index(type_key); if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) { LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE"; } - return fcreate_[tindex](global_key); + return fcreate_[tindex](repr_bytes); } class NodeAttrSetter : public AttrVisitor { diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 11c9e8fc8cb6..ee6072d77c1c 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -32,6 +32,7 @@ #include #include +#include #include #include "../support/base64.h" @@ -46,6 +47,26 @@ inline DataType String2Type(std::string s) { return DataType(runtime::String2DLDataType(s)); } +inline std::string Base64Decode(std::string s) { + dmlc::MemoryStringStream mstrm(&s); + support::Base64InStream b64strm(&mstrm); + std::string output; + b64strm.InitPosition(); + dmlc::Stream* strm = &b64strm; + strm->Read(&output); + return output; +} + +inline std::string Base64Encode(std::string s) { + std::string blob; + dmlc::MemoryStringStream mstrm(&blob); + support::Base64OutStream b64strm(&mstrm); + dmlc::Stream* strm = &b64strm; + strm->Write(s); + b64strm.Finish(); + return blob; +} + // indexer to index all the nodes class NodeIndexer : public AttrVisitor { public: @@ -103,7 +124,10 @@ class NodeIndexer : public AttrVisitor { MakeIndex(const_cast(kv.second.get())); } } else { - reflection_->VisitAttrs(node, this); + // if the node already has repr bytes, no need to visit Attrs. + if (!reflection_->GetReprBytes(node, nullptr)) { + reflection_->VisitAttrs(node, this); + } } } }; @@ -115,8 +139,8 @@ using AttrMap = std::map; struct JSONNode { /*! \brief The type of key of the object. 
*/ std::string type_key; - /*! \brief The global key for global object. */ - std::string global_key; + /*! \brief The str repr representation. */ + std::string repr_bytes; /*! \brief the attributes */ AttrMap attrs; /*! \brief keys of a map. */ @@ -127,8 +151,15 @@ struct JSONNode { void Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("type_key", type_key); - if (global_key.size() != 0) { - writer->WriteObjectKeyValue("global_key", global_key); + if (repr_bytes.size() != 0) { + // choose to use str representation or base64, based on whether + // the byte representation is printable. + if (std::all_of(repr_bytes.begin(), repr_bytes.end(), + [](char ch) { return std::isprint(ch); })) { + writer->WriteObjectKeyValue("repr_str", repr_bytes); + } else { + writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes)); + } } if (attrs.size() != 0) { writer->WriteObjectKeyValue("attrs", attrs); @@ -145,15 +176,24 @@ struct JSONNode { void Load(dmlc::JSONReader *reader) { attrs.clear(); data.clear(); - global_key.clear(); + repr_bytes.clear(); type_key.clear(); + std::string repr_b64, repr_str; dmlc::JSONObjectReadHelper helper; helper.DeclareOptionalField("type_key", &type_key); - helper.DeclareOptionalField("global_key", &global_key); + helper.DeclareOptionalField("repr_b64", &repr_b64); + helper.DeclareOptionalField("repr_str", &repr_str); helper.DeclareOptionalField("attrs", &attrs); helper.DeclareOptionalField("keys", &keys); helper.DeclareOptionalField("data", &data); helper.ReadAllFields(reader); + + if (repr_str.size() != 0) { + CHECK_EQ(repr_b64.size(), 0U); + repr_bytes = std::move(repr_str); + } else if (repr_b64.size() != 0) { + repr_bytes = Base64Decode(repr_b64); + } } }; @@ -212,10 +252,8 @@ class JSONAttrGetter : public AttrVisitor { return; } node_->type_key = node->GetTypeKey(); - node_->global_key = reflection_->GetGlobalKey(node); - // No need to recursively visit fields of global singleton - // They are 
registered via the environment. - if (node_->global_key.length() != 0) return; + // do not need to print additional things once we have repr bytes. + if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return; // populates the fields. node_->attrs.clear(); @@ -434,7 +472,7 @@ ObjectRef LoadJSON(std::string json_str) { for (const JSONNode& jnode : jgraph.nodes) { if (jnode.type_key.length() != 0) { ObjectPtr node = - reflection->CreateInitObject(jnode.type_key, jnode.global_key); + reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); nodes.emplace_back(node); } else { nodes.emplace_back(ObjectPtr()); @@ -447,9 +485,12 @@ ObjectRef LoadJSON(std::string json_str) { for (size_t i = 0; i < nodes.size(); ++i) { setter.node_ = &jgraph.nodes[i]; - // do not need to recover content of global singleton object - // they are registered via the environment - if (setter.node_->global_key.length() == 0) { + // Skip the nodes that have a repr bytes representation. + // NOTE: the second condition is used to guard the case + // where the repr bytes itself is an empty string "". + if (setter.node_->repr_bytes.length() == 0 && + nodes[i] != nullptr && + !reflection->GetReprBytes(nodes[i].get(), nullptr)) { setter.Set(nodes[i].get()); } } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 54812be62d9b..16d02d2cc224 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -16,6 +16,7 @@ # under the License. 
import tvm +from tvm import relay from tvm import te import json @@ -108,6 +109,22 @@ def test_global_var(): assert isinstance(tvar, tvm.ir.GlobalVar) +def test_op(): + nodes = [ + {"type_key": ""}, + {"type_key": "relay.Op", + "global_key": "nn.conv2d"} + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + op = tvm.ir.load_json(json.dumps(data)) + assert op == relay.op.get("nn.conv2d") + + def test_tir_var(): nodes = [ {"type_key": ""}, @@ -132,6 +149,7 @@ def test_tir_var(): if __name__ == "__main__": + test_op() test_type_var() test_incomplete_type() test_func_tuple_type() diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index f2848ff0ef50..975192293d87 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -89,7 +89,20 @@ def test(x): assert x.func(10) == 11 +def test_string(): + # non printable str, need to store by b64 + s1 = tvm.runtime.String("xy\x01z") + s2 = tvm.ir.load_json(tvm.ir.save_json(s1)) + tvm.ir.assert_structural_equal(s1, s2) + + # printable str, need to store by repr_str + s1 = tvm.runtime.String("xyz") + s2 = tvm.ir.load_json(tvm.ir.save_json(s1)) + tvm.ir.assert_structural_equal(s1, s2) + + if __name__ == "__main__": + test_string() test_env_func() test_make_node() test_make_smap()