diff --git a/include/nnvm/c_api.h b/include/nnvm/c_api.h index f3ae626f3afc..028c88d98154 100644 --- a/include/nnvm/c_api.h +++ b/include/nnvm/c_api.h @@ -266,6 +266,14 @@ NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, */ NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out); +/*! + * \brief Get a symbol that contains only direct children. + * \param symbol The symbol + * \param out The output symbol whose outputs are the direct children. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, + SymbolHandle *out); /*! * \brief Get index-th outputs of the symbol. * \param symbol The symbol diff --git a/python/nnvm/symbol.py b/python/nnvm/symbol.py index bca63a3656f2..0ac51f23fe94 100644 --- a/python/nnvm/symbol.py +++ b/python/nnvm/symbol.py @@ -143,6 +143,9 @@ def __getitem__(self, index): self.handle, _base.nn_uint(index), _ctypes.byref(handle))) return Symbol(handle=handle) + def __iter__(self): + return (self[i] for i in self.list_output_names()) + def attr(self, key): """Get attribute string from the symbol, this function only works for non-grouped symbol. @@ -196,6 +199,17 @@ def get_internals(self): self.handle, _ctypes.byref(handle))) return Symbol(handle=handle) + def get_children(self): + """Gets a new grouped symbol whose output contains + inputs to output nodes of the original symbol.""" + handle = _base.SymbolHandle() + _check_call(_LIB.NNSymbolGetChildren( + self.handle, _ctypes.byref(handle))) + ret = Symbol(handle=handle) + if not ret.list_output_names(): + return None + return ret + def _get_list_copt(self, option): """internal function to get list option""" if option == 'all': diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index feb12fbc6d65..0327ca4ae8a7 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -141,6 +141,15 @@ int NNSymbolGetInternals(SymbolHandle symbol, API_END_HANDLE_ERROR(delete s); } +int NNSymbolGetChildren(SymbolHandle symbol, + SymbolHandle *out) { + Symbol *s = new Symbol(); + API_BEGIN(); + *s = static_cast(symbol)->GetChildren(); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + int NNSymbolFree(SymbolHandle symbol) { API_BEGIN(); delete static_cast(symbol);