From 88520e132958f3b03ab7117a9c88d188918e30c3 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 11 Jul 2016 22:14:04 -0700 Subject: [PATCH] Enable use json for graph attr exchange (#5) --- nnvm/include/nnvm/c_api.h | 29 ++++++++++++++++++----------- nnvm/python/nnvm/base.py | 1 + nnvm/python/nnvm/graph.py | 29 ++++++++++++++++++++--------- nnvm/src/c_api/c_api_graph.cc | 26 ++++++++++++++++++-------- nnvm/src/pass/saveload_json.cc | 4 ++++ nnvm/tests/python/test_graph.py | 11 ++++++++++- 6 files changed, 71 insertions(+), 29 deletions(-) diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index d15ba59b46b2..46ecf2d510dc 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -248,27 +248,34 @@ NNVM_DLL int NNGraphFree(GraphHandle handle); */ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); /*! - * \brief Get Set a std::string typed attribute to graph. + * \brief Get Set a attribute in json format. + * This feature allows pass graph attributes back and forth in reasonable speed. + * * \param handle The graph handle. * \param key The key to the attribute. - * \param value The value to be exposed. + * \param json_value The value need to be in format [type_name, value], + * Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle, - const char* key, - const char* value); +NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, + const char* key, + const char* json_value); /*! - * \brief Get Set a std::string typed attribute from graph attribute. + * \brief Get a serialized attrirbute from graph. + * This feature allows pass graph attributes back and forth in reasonable speed. + * * \param handle The graph handle. * \param key The key to the attribute. - * \param out The result attribute, can be NULL if the attribute do not exist. + * \param json_out The result attribute, can be NULL if the attribute do not exist. + * The json_out is an array of [type_name, value]. + * Where the type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. * \param success Whether the result is contained in out. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle, - const char* key, - const char** out, - int *success); +NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle, + const char* key, + const char** json_out, + int *success); /*! * \brief Apply pass on the src graph. * \param src The source graph handle. diff --git a/nnvm/python/nnvm/base.py b/nnvm/python/nnvm/base.py index bc3d3b7b6ecf..cf5ead2b4ab9 100644 --- a/nnvm/python/nnvm/base.py +++ b/nnvm/python/nnvm/base.py @@ -47,6 +47,7 @@ def _load_lib(): SymbolHandle = ctypes.c_void_p GraphHandle = ctypes.c_void_p + #---------------------------- # helper function definition #---------------------------- diff --git a/nnvm/python/nnvm/graph.py b/nnvm/python/nnvm/graph.py index c09f98173ea4..6661d9ca7140 100644 --- a/nnvm/python/nnvm/graph.py +++ b/nnvm/python/nnvm/graph.py @@ -5,12 +5,14 @@ import ctypes import sys +import json from .base import _LIB from .base import c_array, c_str, nn_uint, py_str, string_types from .base import GraphHandle, SymbolHandle from .base import check_call from .symbol import Symbol + class Graph(object): """Graph is the graph object that can be used to apply optimization pass. It contains additional graphwise attribute besides the internal symbol. @@ -31,7 +33,7 @@ def __init__(self, handle): def __del__(self): check_call(_LIB.NNGraphFree(self.handle)) - def attr(self, key): + def json_attr(self, key): """Get attribute string from the graph. Parameters @@ -46,24 +48,33 @@ def attr(self, key): """ ret = ctypes.c_char_p() success = ctypes.c_int() - check_call(_LIB.NNGraphGetStrAttr( + check_call(_LIB.NNGraphGetJSONAttr( self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success))) if success.value != 0: - return py_str(ret.value) + json_str = py_str(ret.value) + return json.loads(json_str)[1] else: return None - def _set_attr(self, **kwargs): + def _set_json_attr(self, key, value, type_name=None): """Set the attribute of the symbol. Parameters ---------- - **kwargs - The attributes to set + key : string + The key of the attribute + value : value + The any type that can be dumped to json + type_name : string + The typename registered on c++ side. """ - for k, v in kwargs.items(): - check_call(_LIB.NNGraphSetStrAttr( - self.handle, c_str(k), c_str(v))) + if isinstance(value, string_types): + type_name = 'str' + elif type_name is None: + raise ValueError("Need to specify type_name") + json_value = json.dumps([type_name, value]) + check_call(_LIB.NNGraphSetJSONAttr( + self.handle, c_str(key), c_str(json_value))) @property def symbol(self): diff --git a/nnvm/src/c_api/c_api_graph.cc b/nnvm/src/c_api/c_api_graph.cc index 3e8f573c2d84..c3de70618b1f 100644 --- a/nnvm/src/c_api/c_api_graph.cc +++ b/nnvm/src/c_api/c_api_graph.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include "./c_api_common.h" using namespace nnvm; @@ -34,26 +35,35 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { API_END_HANDLE_ERROR(delete s); } -int NNGraphSetStrAttr(GraphHandle handle, - const char* key, - const char* value) { +int NNGraphSetJSONAttr(GraphHandle handle, + const char* key, + const char* json_value) { API_BEGIN(); Graph* g = static_cast(handle); - g->attrs[std::string(key)] = std::make_shared(std::string(value)); + std::string temp(json_value); + std::istringstream is(temp); + dmlc::JSONReader reader(&is); + nnvm::any value; + reader.Read(&value); + g->attrs[std::string(key)] = std::make_shared(std::move(value)); API_END(); } -int NNGraphGetStrAttr(GraphHandle handle, +int NNGraphGetJSONAttr(GraphHandle handle, const char* key, - const char** out, + const char** json_out, int *success) { + NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); Graph* g = static_cast(handle); std::string skey(key); auto it = g->attrs.find(skey); if (it != g->attrs.end()) { - const std::string& str = nnvm::get(*it->second.get()); - *out = str.c_str(); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.Write(*it->second.get()); + ret->ret_str = os.str(); + *json_out = (ret->ret_str).c_str(); *success = 1; } else { *success = 0; diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 309bb9336eb3..fee94fbf6eb9 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -203,5 +203,9 @@ NNVM_REGISTER_PASS(SaveJSON) .set_change_graph(true) .provide_graph_attr("json"); + +DMLC_JSON_ENABLE_ANY(std::string, str); +DMLC_JSON_ENABLE_ANY(std::vector, list_int); + } // namespace pass } // namespace nnvm diff --git a/nnvm/tests/python/test_graph.py b/nnvm/tests/python/test_graph.py index 160098ab52c9..8fa392db3425 100644 --- a/nnvm/tests/python/test_graph.py +++ b/nnvm/tests/python/test_graph.py @@ -6,9 +6,18 @@ def test_json_pass(): y = sym.conv2d(data=x, name='conv', stride=(2,2)) g = graph.create(y) ret = g.apply('SaveJSON') + ret._set_json_attr('json', ret.json_attr('json')) g2 = ret.apply('LoadJSON') - assert g2.apply('SaveJSON').attr('json') == ret.attr('json') + assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json') + +def test_graph_json_attr(): + x = sym.Variable('x') + y = sym.conv2d(data=x, name='conv', stride=(2,2)) + g = graph.create(y) + g._set_json_attr('ilist', [1,2,3], 'list_int') + assert g.json_attr('ilist') == [1,2,3] if __name__ == "__main__": + test_graph_json_attr() test_json_pass()