diff --git a/nnvm/Makefile b/nnvm/Makefile index 84040be98292c..feba8d278183e 100644 --- a/nnvm/Makefile +++ b/nnvm/Makefile @@ -1,9 +1,9 @@ export LDFLAGS = -pthread -lm export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops\ - -Iinclude -Idmlc-core/include -I../include -fPIC -L../lib + -Iinclude -Idmlc-core/include -fPIC # specify tensor path -.PHONY: clean all test lint doc python +.PHONY: clean all test lint doc cython cython3 all: lib/libnnvm.so lib/libnnvm.a cli_test @@ -31,9 +31,13 @@ lib/libnnvm.a: $(ALL_DEP) cli_test: $(ALL_DEP) build/test_main.o $(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS) -python: +cython: cd python; python setup.py build_ext --inplace +cython3: + cd python; python3 setup.py build_ext --inplace + + lint: python2 dmlc-core/scripts/lint.py nnvm cpp include src diff --git a/nnvm/python/nnvm/__init__.py b/nnvm/python/nnvm/__init__.py index 38f943b9fd6c0..d30b1c152e2bd 100644 --- a/nnvm/python/nnvm/__init__.py +++ b/nnvm/python/nnvm/__init__.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # coding: utf-8 """NNVM python API for ease of use and help new framework establish python API. """ -from __future__ import absolute_import +from __future__ import absolute_import as _abs -from . import base +from . import _base from . import symbol as sym from . import symbol +from ._base import NNVMError -__version__ = base.__version__ +__version__ = _base.__version__ diff --git a/nnvm/python/nnvm/base.py b/nnvm/python/nnvm/_base.py similarity index 98% rename from nnvm/python/nnvm/base.py rename to nnvm/python/nnvm/_base.py index cf5ead2b4ab91..825a3d380f38b 100644 --- a/nnvm/python/nnvm/base.py +++ b/nnvm/python/nnvm/_base.py @@ -4,6 +4,7 @@ from __future__ import absolute_import import sys +import os import ctypes import numpy as np from . import libinfo @@ -31,7 +32,7 @@ class NNVMError(Exception): def _load_lib(): """Load libary by searching possible path.""" lib_path = libinfo.find_lib_path() - lib = ctypes.cdll.LoadLibrary(lib_path[0]) + lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) # DMatrix functions lib.NNGetLastError.restype = ctypes.c_char_p return lib @@ -41,13 +42,13 @@ def _load_lib(): # library instance of nnvm _LIB = _load_lib() + # type definitions nn_uint = ctypes.c_uint SymbolCreatorHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p GraphHandle = ctypes.c_void_p - #---------------------------- # helper function definition #---------------------------- diff --git a/nnvm/python/nnvm/_cy2/README b/nnvm/python/nnvm/_cy2/README new file mode 100644 index 0000000000000..ed4639b674a05 --- /dev/null +++ b/nnvm/python/nnvm/_cy2/README @@ -0,0 +1 @@ +This folder is by default empty and will hold DLLs generated by cython. diff --git a/nnvm/python/nnvm/_cy2/__init__.py b/nnvm/python/nnvm/_cy2/__init__.py new file mode 100644 index 0000000000000..910cbe2e586b4 --- /dev/null +++ b/nnvm/python/nnvm/_cy2/__init__.py @@ -0,0 +1 @@ +"""Namespace for cython generated modules for python2""" diff --git a/nnvm/python/nnvm/_cy3/README b/nnvm/python/nnvm/_cy3/README new file mode 100644 index 0000000000000..dc3a576037829 --- /dev/null +++ b/nnvm/python/nnvm/_cy3/README @@ -0,0 +1 @@ +This folder is by default empty and will hold DLLs generated by cython. \ No newline at end of file diff --git a/nnvm/python/nnvm/_cy3/__init__.py b/nnvm/python/nnvm/_cy3/__init__.py new file mode 100644 index 0000000000000..c3eb41421b3ed --- /dev/null +++ b/nnvm/python/nnvm/_cy3/__init__.py @@ -0,0 +1 @@ +"""Cython generated modules""" diff --git a/nnvm/python/nnvm/attribute.py b/nnvm/python/nnvm/attribute.py index b60ea01504929..a023b9cd88df1 100644 --- a/nnvm/python/nnvm/attribute.py +++ b/nnvm/python/nnvm/attribute.py @@ -2,7 +2,7 @@ """Attribute scoping support for symbolic API.""" from __future__ import absolute_import -from .base import string_types +from ._base import string_types class AttrScope(object): """Attribute manager for scoping. @@ -59,4 +59,3 @@ def __exit__(self, ptype, value, trace): AttrScope.current = self._old_scope AttrScope.current = AttrScope() - diff --git a/nnvm/python/nnvm/ctypes/README b/nnvm/python/nnvm/ctypes/README new file mode 100644 index 0000000000000..6e82cb962f992 --- /dev/null +++ b/nnvm/python/nnvm/ctypes/README @@ -0,0 +1 @@ +Ctypes specific implementation of certain modules \ No newline at end of file diff --git a/nnvm/python/nnvm/ctypes/__init__.py b/nnvm/python/nnvm/ctypes/__init__.py new file mode 100644 index 0000000000000..fc76dabf682b4 --- /dev/null +++ b/nnvm/python/nnvm/ctypes/__init__.py @@ -0,0 +1 @@ +""""ctypes implementation of the Symbol""" diff --git a/nnvm/python/nnvm/ctypes/symbol.py b/nnvm/python/nnvm/ctypes/symbol.py new file mode 100644 index 0000000000000..503fc09bde745 --- /dev/null +++ b/nnvm/python/nnvm/ctypes/symbol.py @@ -0,0 +1,385 @@ +# coding: utf-8 +# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines +"""Symbolic configuration API.""" +from __future__ import absolute_import as _abs + +import copy +import ctypes +import sys +from .._base import _LIB +from .._base import c_array, c_str, nn_uint, py_str, string_types +from .._base import SymbolHandle +from .._base import check_call, ctypes2docstring +from ..name import NameManager +from ..attribute import AttrScope + +__all__ = ["Symbol", "Variable"] + +class Symbol(object): + """Symbol is symbolic graph.""" + + # pylint: disable=no-member + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolHandle + the handle to the underlying C++ Symbol + """ + self.handle = handle + + def __del__(self): + check_call(_LIB.NNSymbolFree(self.handle)) + + def __copy__(self): + return copy.deepcopy(self) + + def __deepcopy__(self, _): + handle = SymbolHandle() + check_call(_LIB.NNSymbolCopy(self.handle, + ctypes.byref(handle))) + return Symbol(handle) + + def __call__(self, *args, **kwargs): + """Invoke symbol as function on inputs. + + Parameters + ---------- + args: + provide positional arguments + + kwargs: + provide keyword arguments + Returns + ------- + the resulting symbol + """ + s = copy.deepcopy(self) + s._compose(*args, **kwargs) + return s + + def _compose(self, *args, **kwargs): + """Compose symbol on inputs. + + This call mutates the current symbol. + + Parameters + ---------- + args: + provide positional arguments + + kwargs: + provide keyword arguments + + Returns + ------- + the resulting symbol + """ + name = kwargs.pop('name', None) + + if name: + name = c_str(name) + if len(args) != 0 and len(kwargs) != 0: + raise TypeError('compose only accept input Symbols \ + either as positional or keyword arguments, not both') + + for arg in args: + if not isinstance(arg, Symbol): + raise TypeError('Compose expect `Symbol` as arguments') + for val in kwargs.values(): + if not isinstance(val, Symbol): + raise TypeError('Compose expect `Symbol` as arguments') + + num_args = len(args) + len(kwargs) + if len(kwargs) != 0: + keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) + args = c_array(SymbolHandle, [s.handle for s in kwargs.values()]) + else: + keys = None + args = c_array(SymbolHandle, [s.handle for s in args]) + check_call(_LIB.NNSymbolCompose( + self.handle, name, num_args, keys, args)) + + def __getitem__(self, index): + if isinstance(index, string_types): + idx = None + for i, name in enumerate(self.list_outputs()): + if name == index: + if idx is not None: + raise ValueError('There are multiple outputs with name \"%s\"' % index) + idx = i + if idx is None: + raise ValueError('Cannot find output that matches name \"%s\"' % index) + index = idx + if not isinstance(index, int): + raise TypeError('Symbol only support integer index to fetch i-th output') + handle = SymbolHandle() + check_call(_LIB.NNSymbolGetOutput( + self.handle, nn_uint(index), ctypes.byref(handle))) + return Symbol(handle=handle) + + def attr(self, key): + """Get attribute string from the symbol, this function only works for non-grouped symbol. + + Parameters + ---------- + key : str + The key to get attribute from. + + Returns + ------- + value : str + The attribute value of the key, returns None if attribute do not exist. + """ + ret = ctypes.c_char_p() + success = ctypes.c_int() + check_call(_LIB.NNSymbolGetAttr( + self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success))) + if success.value != 0: + return py_str(ret.value) + else: + return None + + def list_attr(self, recursive=False): + """Get all attributes from the symbol. + + Parameters + ---------- + recursive : bool + Default `False`. When `recursive` is `True`, list recursively all the + attributes in the descendents. The attribute names are pre-pended with + the symbol names to avoid conflicts. If `False`, then only attributes + that belongs to this symbol is returned, and the attribute names will + **not** be pre-pended with the symbol name. + """ + size = nn_uint() + pairs = ctypes.POINTER(ctypes.c_char_p)() + option = ctypes.c_int(0) if recursive else ctypes.c_int(1) + check_call(_LIB.NNSymbolListAttrs( + self.handle, option, ctypes.byref(size), ctypes.byref(pairs))) + return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)} + + def _set_attr(self, **kwargs): + """Set the attribute of the symbol. + + Parameters + ---------- + **kwargs + The attributes to set + """ + keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) + vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) + num_args = nn_uint(len(kwargs)) + check_call(_LIB.NNSymbolSetAttrs( + self.handle, num_args, keys, vals)) + + def get_internals(self): + """Get a new grouped symbol whose output contains all the internal outputs of this symbol. + + Returns + ------- + sgroup : Symbol + The internal of the symbol. + """ + handle = SymbolHandle() + check_call(_LIB.NNSymbolGetInternals( + self.handle, ctypes.byref(handle))) + return Symbol(handle=handle) + + def list_arguments(self): + """List all the arguments in the symbol. + + Returns + ------- + args : list of string + List of all the arguments. + """ + size = ctypes.c_uint() + sarr = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.NNSymbolListArguments( + self.handle, ctypes.byref(size), ctypes.byref(sarr))) + return [py_str(sarr[i]) for i in range(size.value)] + + def list_outputs(self): + """List all outputs in the symbol. + + Returns + ------- + returns : list of string + List of all the outputs. + """ + size = ctypes.c_uint() + sarr = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.NNSymbolListOutputs( + self.handle, ctypes.byref(size), ctypes.byref(sarr))) + return [py_str(sarr[i]) for i in range(size.value)] + + def debug_str(self): + """Get a debug string. + + Returns + ------- + debug_str : string + Debug string of the symbol. + """ + debug_str = ctypes.c_char_p() + check_call(_LIB.NNSymbolPrint( + self.handle, ctypes.byref(debug_str))) + return py_str(debug_str.value) + + +def Variable(name, **kwargs): + """Create a symbolic variable with specified name. + + Parameters + ---------- + name : str + Name of the variable. + kwargs : dict of string -> string + Additional attributes to set on the variable. + + Returns + ------- + variable : Symbol + The created variable symbol. + """ + if not isinstance(name, string_types): + raise TypeError('Expect a string for variable `name`') + handle = SymbolHandle() + check_call(_LIB.NNSymbolCreateVariable(c_str(name), ctypes.byref(handle))) + ret = Symbol(handle) + attr = AttrScope.current.get(kwargs) + if attr: + ret._set_attr(**attr) + return ret + + +def Group(symbols): + """Create a symbol that groups symbols together. + + Parameters + ---------- + symbols : list + List of symbols to be grouped. + + Returns + ------- + sym : Symbol + The created group symbol. + """ + ihandles = [] + for sym in symbols: + if not isinstance(sym, Symbol): + raise TypeError('Expect Symbols in the list input') + ihandles.append(sym.handle) + handle = SymbolHandle() + check_call(_LIB.NNSymbolCreateGroup( + nn_uint(len(ihandles)), + c_array(SymbolHandle, ihandles), ctypes.byref(handle))) + return Symbol(handle) + + +def _make_atomic_symbol_function(handle): + """Create an atomic symbol function by handle and funciton name.""" + name = ctypes.c_char_p() + desc = ctypes.c_char_p() + num_args = nn_uint() + arg_names = ctypes.POINTER(ctypes.c_char_p)() + arg_types = ctypes.POINTER(ctypes.c_char_p)() + arg_descs = ctypes.POINTER(ctypes.c_char_p)() + ret_type = ctypes.c_char_p() + + check_call(_LIB.NNSymbolGetAtomicSymbolInfo( + handle, ctypes.byref(name), ctypes.byref(desc), + ctypes.byref(num_args), + ctypes.byref(arg_names), + ctypes.byref(arg_types), + ctypes.byref(arg_descs), + ctypes.byref(ret_type))) + param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) + func_name = py_str(name.value) + desc = py_str(desc.value) + + doc_str = ('%s\n\n' + + '%s\n' + + 'name : string, optional.\n' + + ' Name of the resulting symbol.\n\n' + + 'Returns\n' + + '-------\n' + + 'symbol: Symbol\n' + + ' The result symbol.') + doc_str = doc_str % (desc, param_str) + + def creator(*args, **kwargs): + """Activation Operator of Neural Net. + The parameters listed below can be passed in as keyword arguments. + + Parameters + ---------- + name : string, required. + Name of the resulting symbol. + + Returns + ------- + symbol: Symbol + the resulting symbol + """ + param_keys = [] + param_vals = [] + symbol_kwargs = {} + name = kwargs.pop('name', None) + attr = kwargs.pop('attr', None) + + for k, v in kwargs.items(): + if isinstance(v, Symbol): + symbol_kwargs[k] = v + else: + param_keys.append(c_str(k)) + param_vals.append(c_str(str(v))) + # create atomic symbol + param_keys = c_array(ctypes.c_char_p, param_keys) + param_vals = c_array(ctypes.c_char_p, param_vals) + sym_handle = SymbolHandle() + check_call(_LIB.NNSymbolCreateAtomicSymbol( + handle, + nn_uint(len(param_keys)), + param_keys, param_vals, + ctypes.byref(sym_handle))) + + if len(args) != 0 and len(symbol_kwargs) != 0: + raise TypeError( + '%s can only accept input' + 'Symbols either as positional or keyword arguments, not both' % func_name) + s = Symbol(sym_handle) + attr = AttrScope.current.get(attr) + if attr: + s._set_attr(**attr) + hint = func_name.lower() + name = NameManager.current.get(name, hint) + s._compose(*args, name=name, **symbol_kwargs) + return s + + creator.__name__ = func_name + creator.__doc__ = doc_str + return creator + + +def _init_symbol_module(): + """List and add all the atomic symbol functions to current module.""" + plist = ctypes.POINTER(ctypes.c_void_p)() + size = ctypes.c_uint() + + check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size), + ctypes.byref(plist))) + module_obj = sys.modules["nnvm.symbol"] + for i in range(size.value): + hdl = SymbolHandle(plist[i]) + function = _make_atomic_symbol_function(hdl) + if function.__name__.startswith('_'): + setattr(Symbol, function.__name__, staticmethod(function)) + else: + setattr(module_obj, function.__name__, function) + +# Initialize the atomic symbol in startups +_init_symbol_module() diff --git a/nnvm/python/nnvm/cython/README b/nnvm/python/nnvm/cython/README new file mode 100644 index 0000000000000..d9deab1abca9a --- /dev/null +++ b/nnvm/python/nnvm/cython/README @@ -0,0 +1 @@ +Cython specific implementation of certain modules \ No newline at end of file diff --git a/nnvm/python/nnvm/cython/base.pyi b/nnvm/python/nnvm/cython/base.pyi new file mode 100644 index 0000000000000..f9175651dea94 --- /dev/null +++ b/nnvm/python/nnvm/cython/base.pyi @@ -0,0 +1,67 @@ +ctypedef void* SymbolHandle +ctypedef void* AtomicSymbolCreator +ctypedef unsigned nn_uint + +cdef py_str(const char* x): + if PY_MAJOR_VERSION < 3: + return x + else: + return x.decode("utf-8") + + +cdef CALL(int ret): + if ret != 0: + raise NNVMError(NNGetLastError()) + + +cdef const char** CBeginPtr(vector[const char*]& vec): + if (vec.size() != 0): + return &vec[0] + else: + return NULL + + +cdef BuildDoc(nn_uint num_args, + const char** arg_names, + const char** arg_types, + const char** arg_descs, + remove_dup=True): + """Convert ctypes returned doc string information into parameters docstring. + + num_args : nn_uint + Number of arguments. + + arg_names : ctypes.POINTER(ctypes.c_char_p) + Argument names. + + arg_types : ctypes.POINTER(ctypes.c_char_p) + Argument type information. + + arg_descs : ctypes.POINTER(ctypes.c_char_p) + Argument description information. + + remove_dup : boolean, optional + Whether remove duplication or not. + + Returns + ------- + docstr : str + Python docstring of parameter sections. + """ + param_keys = set() + param_str = [] + for i in range(num_args): + key = arg_names[i] + if key in param_keys and remove_dup: + continue + param_keys.add(key) + type_info = arg_types[i] + ret = '%s : %s' % (key, type_info) + if len(arg_descs[i]) != 0: + ret += '\n ' + arg_descs[i] + param_str.append(ret) + doc_str = ('Parameters\n' + + '----------\n' + + '%s\n') + doc_str = doc_str % ('\n'.join(param_str)) + return doc_str diff --git a/nnvm/python/nnvm/symbolx.pyd b/nnvm/python/nnvm/cython/symbol.pyd similarity index 100% rename from nnvm/python/nnvm/symbolx.pyd rename to nnvm/python/nnvm/cython/symbol.pyd diff --git a/nnvm/python/nnvm/cython/symbol.pyx b/nnvm/python/nnvm/cython/symbol.pyx new file mode 100644 index 0000000000000..5554520cdf5ba --- /dev/null +++ b/nnvm/python/nnvm/cython/symbol.pyx @@ -0,0 +1,376 @@ +from __future__ import absolute_import as _abs + +import sys as _sys +import ctypes as _ctypes +from .._base import NNVMError +from ..name import NameManager +from ..attribute import AttrScope +from libcpp.vector cimport vector +from cpython.version cimport PY_MAJOR_VERSION + +include "./base.pyi" + +cdef extern from "nnvm/c_api.h": + const char* NNGetLastError(); + int NNSymbolCreateVariable(const char *name, SymbolHandle *out); + int NNSymbolCreateGroup(nn_uint num_symbols, + SymbolHandle *symbols, + SymbolHandle *out); + int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, + AtomicSymbolCreator **out_array); + int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, + nn_uint num_param, + const char **keys, + const char **vals, + SymbolHandle *out); + int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, + const char **name, + const char **description, + nn_uint *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type); + int NNSymbolFree(SymbolHandle symbol); + int NNSymbolPrint(SymbolHandle symbol, const char **out_str); + int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); + int NNSymbolGetAttr(SymbolHandle symbol, + const char* key, + const char** out, + int *success); + int NNSymbolSetAttrs(SymbolHandle symbol, + nn_uint num_param, + const char** keys, + const char** values); + int NNSymbolListAttrs(SymbolHandle symbol, + int recursive_option, + nn_uint *out_size, + const char*** out); + int NNSymbolListArguments(SymbolHandle symbol, + nn_uint *out_size, + const char ***out_str_array); + int NNSymbolListOutputs(SymbolHandle symbol, + nn_uint *out_size, + const char ***out_str_array); + int NNSymbolGetInternals(SymbolHandle symbol, + SymbolHandle *out); + int NNSymbolGetOutput(SymbolHandle symbol, + nn_uint index, + SymbolHandle *out); + int NNSymbolCompose(SymbolHandle sym, + const char* name, + nn_uint num_args, + const char** keys, + SymbolHandle* args); + + +cdef class Symbol: + """Symbol is symbolic graph.""" + # handle for symbolic operator. + cdef SymbolHandle handle + + def __init__(self, handle): + cdef unsigned long ptr + if handle is None: + self.handle = NULL + else: + ptr = handle.value + self.handle = (ptr) + + def __dealloc__(self): + CALL(NNSymbolFree(self.handle)) + + @property + def handle(self): + return _ctypes.cast(self.handle, _ctypes.c_void_p) + + def __copy__(self): + return self.__deepcopy__() + + def __deepcopy__(self, _ = None): + cdef SymbolHandle handle + CALL(NNSymbolCopy(self.handle, &handle)) + return NewSymbol(handle) + + def __getitem__(self, index): + if isinstance(index, str): + idx = None + for i, name in enumerate(self.list_outputs()): + if name == index: + if idx is not None: + raise ValueError('There are multiple outputs with name \"%s\"' % index) + idx = i + if idx is None: + raise ValueError('Cannot find output that matches name \"%s\"' % index) + index = idx + if not isinstance(index, int): + raise TypeError('Symbol only support integer index to fetch i-th output') + cdef SymbolHandle handle + cdef nn_uint c_index = index + CALL(NNSymbolGetOutput(self.handle, c_index, &handle)) + return NewSymbol(handle) + + def attr(self, const char* key): + """Get attribute string from the symbol, this function only works for non-grouped symbol. + + Parameters + ---------- + key : str + The key to get attribute from. + + Returns + ------- + value : str + The attribute value of the key, returns None if attribute do not exist. + """ + cdef const char* ret + cdef int success + CALL(NNSymbolGetAttr( + self.handle, key, &ret, &success)) + if success != 0: + return py_str(ret.value) + else: + return None + + def list_attr(self, recursive=False): + """Get all attributes from the symbol. + + Parameters + ---------- + recursive : bool + Default `False`. When `recursive` is `True`, list recursively all the + attributes in the descendents. The attribute names are pre-pended with + the symbol names to avoid conflicts. If `False`, then only attributes + that belongs to this symbol is returned, and the attribute names will + **not** be pre-pended with the symbol name. + """ + cdef nn_uint size + cdef const char** pairs + cdef int option + option = 0 if recursive else 1 + CALL(NNSymbolListAttrs( + self.handle, option, &size, &pairs)) + return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size)} + + def _set_attr(self, **kwargs): + """Set the attribute of the symbol. + + Parameters + ---------- + **kwargs + The attributes to set + """ + SymbolSetAttr(self.handle, kwargs) + + def get_internals(self): + """Get a new grouped symbol whose output contains all the internal outputs of this symbol. + + Returns + ------- + sgroup : Symbol + The internal of the symbol. + """ + cdef SymbolHandle handle + CALL(NNSymbolGetInternals(self.handle, &handle)) + return NewSymbol(handle) + + def list_arguments(self): + """List all the arguments in the symbol. + + Returns + ------- + args : list of string + List of all the arguments. + """ + cdef nn_uint size + cdef const char ** sarr + CALL(NNSymbolListArguments(self.handle, &size, &sarr)) + return [py_str(sarr[i]) for i in range(size)] + + def list_outputs(self): + """List all outputs in the symbol. + + Returns + ------- + returns : list of string + List of all the outputs. + """ + cdef nn_uint size + cdef const char ** sarr + CALL(NNSymbolListOutputs(self.handle, &size, &sarr)) + return [py_str(sarr[i]) for i in range(size)] + + def debug_str(self): + cdef const char* out_str + CALL(NNSymbolPrint(self.handle, &out_str)) + return str(out_str) + + +cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): + cdef vector[const char*] param_keys + cdef vector[const char*] param_vals + cdef nn_uint num_args + for k, v in kwargs.items(): + param_keys.push_back(k) + param_vals.push_back(str(v)) + num_args = param_keys.size() + CALL(NNSymbolSetAttrs( + handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals))) + + +cdef NewSymbol(SymbolHandle handle): + """Create a new symbol given handle""" + sym = Symbol(None) + sym.handle = handle + return sym + + +def Variable(const char* name, **kwargs): + """Create a symbolic variable with specified name. + + Parameters + ---------- + name : str + Name of the variable. + kwargs : dict of string -> string + Additional attributes to set on the variable. + + Returns + ------- + variable : Symbol + The created variable symbol. + """ + cdef SymbolHandle handle + CALL(NNSymbolCreateVariable(name, &handle)) + return NewSymbol(handle) + + +cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): + """Create an atomic symbol function by handle and funciton name.""" + cdef const char *name + cdef const char *desc + cdef nn_uint num_args + cdef const char** arg_names + cdef const char** arg_types + cdef const char** arg_descs + cdef const char* return_type + + CALL(NNSymbolGetAtomicSymbolInfo( + handle, &name, &desc, + &num_args, &arg_names, + &arg_types, &arg_descs, + &return_type)) + param_str = BuildDoc(num_args, arg_names, arg_types, arg_descs) + func_name = py_str(name) + doc_str = ('%s\n\n' + + '%s\n' + + 'name : string, optional.\n' + + ' Name of the resulting symbol.\n\n' + + 'Returns\n' + + '-------\n' + + 'symbol: Symbol\n' + + ' The result symbol.') + doc_str = doc_str % (desc, param_str) + func_hint = func_name.lower() + + def creator(*args, **kwargs): + cdef vector[const char*] param_keys + cdef vector[const char*] param_vals + cdef vector[SymbolHandle] symbol_args + cdef vector[const char*] symbol_keys + cdef SymbolHandle ret_handle + + name = kwargs.pop("name", None) + attr = kwargs.pop("attr", None) + + if len(kwargs) != 0: + for k, v in kwargs.items(): + if isinstance(v, Symbol): + symbol_keys.push_back(k) + symbol_args.push_back((v).handle) + else: + param_keys.push_back(k) + param_vals.push_back(str(v)) + + if len(args) != 0: + if symbol_args.size() != 0: + raise TypeError("compose only accept input Symbols\ + either as positional or keyword arguments, not both") + for v in args: + if not isinstance(v, Symbol): + raise TypeError('Compose expect `Symbol` as arguments') + symbol_args.push_back((v).handle) + + CALL(NNSymbolCreateAtomicSymbol( + handle, + param_keys.size(), + CBeginPtr(param_keys), + CBeginPtr(param_vals), + &ret_handle)) + num_args = (symbol_args.size()) + + attr = AttrScope.current.get(attr) + if attr: + SymbolSetAttr(ret_handle, attr) + name = NameManager.current.get(name, func_hint) + + cdef const char* c_name = NULL + if name: + c_name = name + + CALL(NNSymbolCompose( + ret_handle, + c_name, + num_args, + &symbol_keys[0] if symbol_keys.size() != 0 else NULL, + &symbol_args[0] if symbol_args.size() != 0 else NULL)) + return NewSymbol(ret_handle) + + creator.__name__ = func_name + creator.__doc__ = doc_str + return creator + + +def Group(symbols): + """Create a symbol that groups symbols together. + + Parameters + ---------- + symbols : list + List of symbols to be grouped. + + Returns + ------- + sym : Symbol + The created group symbol. + """ + cdef vector[SymbolHandle] ihandles + cdef SymbolHandle handle + + for sym in symbols: + if not isinstance(sym, Symbol): + raise TypeError("Expect Symbols in the list input") + ihandles.push_back((sym).handle) + if ihandles.size() == 0: + raise ValueError("expect at least one element in the input") + CALL(NNSymbolCreateGroup(ihandles.size(), + &ihandles[0], &handle)) + return NewSymbol(handle) + + +def _init_symbol_module(): + """List and add all the atomic symbol functions to current module.""" + cdef AtomicSymbolCreator* plist + cdef nn_uint size + CALL(NNSymbolListAtomicSymbolCreators(&size, &plist)) + module_obj = _sys.modules["nnvm.symbol"] + for i in range(size): + function = _make_atomic_symbol_function(plist[i]) + if function.__name__.startswith('_'): + setattr(Symbol, function.__name__, staticmethod(function)) + else: + setattr(module_obj, function.__name__, function) + + +# Initialize the atomic symbol in startups +_init_symbol_module() diff --git a/nnvm/python/nnvm/graph.py b/nnvm/python/nnvm/graph.py index 6661d9ca71400..3f184928a9ec8 100644 --- a/nnvm/python/nnvm/graph.py +++ b/nnvm/python/nnvm/graph.py @@ -6,10 +6,10 @@ import ctypes import sys import json -from .base import _LIB -from .base import c_array, c_str, nn_uint, py_str, string_types -from .base import GraphHandle, SymbolHandle -from .base import check_call +from ._base import _LIB +from ._base import c_array, c_str, nn_uint, py_str, string_types +from ._base import GraphHandle, SymbolHandle +from ._base import check_call from .symbol import Symbol diff --git a/nnvm/python/nnvm/name.py b/nnvm/python/nnvm/name.py index 7e50e9e5070d8..081d2bae7242b 100644 --- a/nnvm/python/nnvm/name.py +++ b/nnvm/python/nnvm/name.py @@ -1,6 +1,6 @@ # coding: utf-8 """Automatic naming support for symbolic API.""" -from __future__ import absolute_import +from __future__ import absolute_import as _abs class NameManager(object): """NameManager to do automatic naming. diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index bf9b57ab42056..ac1dee39a08bd 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -1,385 +1,14 @@ -# coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines """Symbolic configuration API.""" from __future__ import absolute_import as _abs - -import copy -import ctypes -import sys -from .base import _LIB -from .base import c_array, c_str, nn_uint, py_str, string_types -from .base import SymbolHandle -from .base import check_call, ctypes2docstring -from .name import NameManager -from .attribute import AttrScope - -__all__ = ["Symbol", "Variable"] - -class Symbol(object): - """Symbol is symbolic graph.""" - - # pylint: disable=no-member - def __init__(self, handle): - """Initialize the function with handle - - Parameters - ---------- - handle : SymbolHandle - the handle to the underlying C++ Symbol - """ - self.handle = handle - - def __del__(self): - check_call(_LIB.NNSymbolFree(self.handle)) - - def __copy__(self): - return copy.deepcopy(self) - - def __deepcopy__(self, _): - handle = SymbolHandle() - check_call(_LIB.NNSymbolCopy(self.handle, - ctypes.byref(handle))) - return Symbol(handle) - - def __call__(self, *args, **kwargs): - """Invoke symbol as function on inputs. - - Parameters - ---------- - args: - provide positional arguments - - kwargs: - provide keyword arguments - Returns - ------- - the resulting symbol - """ - s = copy.deepcopy(self) - s._compose(*args, **kwargs) - return s - - def _compose(self, *args, **kwargs): - """Compose symbol on inputs. - - This call mutates the current symbol. - - Parameters - ---------- - args: - provide positional arguments - - kwargs: - provide keyword arguments - - Returns - ------- - the resulting symbol - """ - name = kwargs.pop('name', None) - - if name: - name = c_str(name) - if len(args) != 0 and len(kwargs) != 0: - raise TypeError('compose only accept input Symbols \ - either as positional or keyword arguments, not both') - - for arg in args: - if not isinstance(arg, Symbol): - raise TypeError('Compose expect `Symbol` as arguments') - for val in kwargs.values(): - if not isinstance(val, Symbol): - raise TypeError('Compose expect `Symbol` as arguments') - - num_args = len(args) + len(kwargs) - if len(kwargs) != 0: - keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) - args = c_array(SymbolHandle, [s.handle for s in kwargs.values()]) - else: - keys = None - args = c_array(SymbolHandle, [s.handle for s in args]) - check_call(_LIB.NNSymbolCompose( - self.handle, name, num_args, keys, args)) - - def __getitem__(self, index): - if isinstance(index, string_types): - idx = None - for i, name in enumerate(self.list_outputs()): - if name == index: - if idx is not None: - raise ValueError('There are multiple outputs with name \"%s\"' % index) - idx = i - if idx is None: - raise ValueError('Cannot find output that matches name \"%s\"' % index) - index = idx - if not isinstance(index, int): - raise TypeError('Symbol only support integer index to fetch i-th output') - handle = SymbolHandle() - check_call(_LIB.NNSymbolGetOutput( - self.handle, nn_uint(index), ctypes.byref(handle))) - return Symbol(handle=handle) - - def attr(self, key): - """Get attribute string from the symbol, this function only works for non-grouped symbol. - - Parameters - ---------- - key : str - The key to get attribute from. - - Returns - ------- - value : str - The attribute value of the key, returns None if attribute do not exist. - """ - ret = ctypes.c_char_p() - success = ctypes.c_int() - check_call(_LIB.NNSymbolGetAttr( - self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success))) - if success.value != 0: - return py_str(ret.value) - else: - return None - - def list_attr(self, recursive=False): - """Get all attributes from the symbol. - - Parameters - ---------- - recursive : bool - Default `False`. When `recursive` is `True`, list recursively all the - attributes in the descendents. The attribute names are pre-pended with - the symbol names to avoid conflicts. If `False`, then only attributes - that belongs to this symbol is returned, and the attribute names will - **not** be pre-pended with the symbol name. - """ - size = nn_uint() - pairs = ctypes.POINTER(ctypes.c_char_p)() - option = ctypes.c_int(0) if recursive else ctypes.c_int(1) - check_call(_LIB.NNSymbolListAttrs( - self.handle, option, ctypes.byref(size), ctypes.byref(pairs))) - return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)} - - def _set_attr(self, **kwargs): - """Set the attribute of the symbol. - - Parameters - ---------- - **kwargs - The attributes to set - """ - keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) - vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) - num_args = nn_uint(len(kwargs)) - check_call(_LIB.NNSymbolSetAttrs( - self.handle, num_args, keys, vals)) - - def get_internals(self): - """Get a new grouped symbol whose output contains all the internal outputs of this symbol. - - Returns - ------- - sgroup : Symbol - The internal of the symbol. - """ - handle = SymbolHandle() - check_call(_LIB.NNSymbolGetInternals( - self.handle, ctypes.byref(handle))) - return Symbol(handle=handle) - - def list_arguments(self): - """List all the arguments in the symbol. - - Returns - ------- - args : list of string - List of all the arguments. - """ - size = ctypes.c_uint() - sarr = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.NNSymbolListArguments( - self.handle, ctypes.byref(size), ctypes.byref(sarr))) - return [py_str(sarr[i]) for i in range(size.value)] - - def list_outputs(self): - """List all outputs in the symbol. - - Returns - ------- - returns : list of string - List of all the outputs. - """ - size = ctypes.c_uint() - sarr = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.NNSymbolListOutputs( - self.handle, ctypes.byref(size), ctypes.byref(sarr))) - return [py_str(sarr[i]) for i in range(size.value)] - - def debug_str(self): - """Get a debug string. - - Returns - ------- - debug_str : string - Debug string of the symbol. - """ - debug_str = ctypes.c_char_p() - check_call(_LIB.NNSymbolPrint( - self.handle, ctypes.byref(debug_str))) - return py_str(debug_str.value) - - -def Variable(name, **kwargs): - """Create a symbolic variable with specified name. - - Parameters - ---------- - name : str - Name of the variable. - kwargs : dict of string -> string - Additional attributes to set on the variable. - - Returns - ------- - variable : Symbol - The created variable symbol. - """ - if not isinstance(name, string_types): - raise TypeError('Expect a string for variable `name`') - handle = SymbolHandle() - check_call(_LIB.NNSymbolCreateVariable(c_str(name), ctypes.byref(handle))) - ret = Symbol(handle) - attr = AttrScope.current.get(kwargs) - if attr: - ret._set_attr(**attr) - return ret - - -def Group(symbols): - """Create a symbol that groups symbols together. - - Parameters - ---------- - symbols : list - List of symbols to be grouped. - - Returns - ------- - sym : Symbol - The created group symbol. - """ - ihandles = [] - for sym in symbols: - if not isinstance(sym, Symbol): - raise TypeError('Expect Symbols in the list input') - ihandles.append(sym.handle) - handle = SymbolHandle() - check_call(_LIB.NNSymbolCreateGroup( - nn_uint(len(ihandles)), - c_array(SymbolHandle, ihandles), ctypes.byref(handle))) - return Symbol(handle) - - -def _make_atomic_symbol_function(handle): - """Create an atomic symbol function by handle and funciton name.""" - name = ctypes.c_char_p() - desc = ctypes.c_char_p() - num_args = nn_uint() - arg_names = ctypes.POINTER(ctypes.c_char_p)() - arg_types = ctypes.POINTER(ctypes.c_char_p)() - arg_descs = ctypes.POINTER(ctypes.c_char_p)() - ret_type = ctypes.c_char_p() - - check_call(_LIB.NNSymbolGetAtomicSymbolInfo( - handle, ctypes.byref(name), ctypes.byref(desc), - ctypes.byref(num_args), - ctypes.byref(arg_names), - ctypes.byref(arg_types), - ctypes.byref(arg_descs), - ctypes.byref(ret_type))) - param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) - func_name = py_str(name.value) - desc = py_str(desc.value) - - doc_str = ('%s\n\n' + - '%s\n' + - 'name : string, optional.\n' + - ' Name of the resulting symbol.\n\n' + - 'Returns\n' + - '-------\n' + - 'symbol: Symbol\n' + - ' The result symbol.') - doc_str = doc_str % (desc, param_str) - - def creator(*args, **kwargs): - """Activation Operator of Neural Net. - The parameters listed below can be passed in as keyword arguments. - - Parameters - ---------- - name : string, required. - Name of the resulting symbol. - - Returns - ------- - symbol: Symbol - the resulting symbol - """ - param_keys = [] - param_vals = [] - symbol_kwargs = {} - name = kwargs.pop('name', None) - attr = kwargs.pop('attr', None) - - for k, v in kwargs.items(): - if isinstance(v, Symbol): - symbol_kwargs[k] = v - else: - param_keys.append(c_str(k)) - param_vals.append(c_str(str(v))) - # create atomic symbol - param_keys = c_array(ctypes.c_char_p, param_keys) - param_vals = c_array(ctypes.c_char_p, param_vals) - sym_handle = SymbolHandle() - check_call(_LIB.NNSymbolCreateAtomicSymbol( - handle, - nn_uint(len(param_keys)), - param_keys, param_vals, - ctypes.byref(sym_handle))) - - if len(args) != 0 and len(symbol_kwargs) != 0: - raise TypeError( - '%s can only accept input' - 'Symbols either as positional or keyword arguments, not both' % func_name) - s = Symbol(sym_handle) - attr = AttrScope.current.get(attr) - if attr: - s._set_attr(**attr) - hint = func_name.lower() - name = NameManager.current.get(name, hint) - s._compose(*args, name=name, **symbol_kwargs) - return s - - creator.__name__ = func_name - creator.__doc__ = doc_str - return creator - - -def _init_symbol_module(): - """List and add all the atomic symbol functions to current module.""" - plist = ctypes.POINTER(ctypes.c_void_p)() - size = ctypes.c_uint() - - check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size), - ctypes.byref(plist))) - module_obj = sys.modules[__name__] - for i in range(size.value): - hdl = SymbolHandle(plist[i]) - function = _make_atomic_symbol_function(hdl) - if function.__name__.startswith('_'): - setattr(Symbol, function.__name__, staticmethod(function)) - else: - setattr(module_obj, function.__name__, function) - -# Initialize the atomic symbol in startups -_init_symbol_module() +import sys as _sys +import os as _os + +try: + if int(_os.environ.get("NNVM_ENABLE_CYTHON", True)) == 0: + from .ctypes.symbol import Symbol, Variable + elif _sys.version_info >= (3, 0): + from ._cy3.symbol import Symbol, Variable, Group + else: + from ._cy2.symbol import Symbol, Variable, Group +except: + from .ctypes.symbol import Symbol, Variable, Group diff --git a/nnvm/python/nnvm/symbolx.pyx b/nnvm/python/nnvm/symbolx.pyx deleted file mode 100644 index 5264a391fcf55..0000000000000 --- a/nnvm/python/nnvm/symbolx.pyx +++ /dev/null @@ -1,220 +0,0 @@ -import sys -from libcpp.vector cimport vector - -ctypedef void* SymbolHandle -ctypedef void* AtomicSymbolCreator -ctypedef unsigned nn_uint - -cdef extern from "nnvm/c_api.h": - int NNSymbolFree(SymbolHandle symbol) - int NNSymbolCreateVariable(const char *name, SymbolHandle *out) - const char* NNGetLastError() - int NNSymbolPrint(SymbolHandle symbol, const char **out_str) - int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, - AtomicSymbolCreator **out_array); - int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, - nn_uint num_param, - const char **keys, - const char **vals, - SymbolHandle *out); - int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, - const char **name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); - int NNSymbolCompose(SymbolHandle sym, - const char* name, - nn_uint num_args, - const char** keys, - SymbolHandle* args); - -cdef CALL(int ret): - if ret != 0: - raise RuntimeError(NNGetLastError()) - -cdef const char** CBeginPtr(vector[const char*]& vec): - if (vec.size() != 0): - return &vec[0] - else: - return NULL - -cdef ctypes2docstring(nn_uint num_args, - const char** arg_names, - const char** arg_types, - const char** arg_descs, - remove_dup=True): - """Convert ctypes returned doc string information into parameters docstring. - - num_args : nn_uint - Number of arguments. - - arg_names : ctypes.POINTER(ctypes.c_char_p) - Argument names. - - arg_types : ctypes.POINTER(ctypes.c_char_p) - Argument type information. - - arg_descs : ctypes.POINTER(ctypes.c_char_p) - Argument description information. - - remove_dup : boolean, optional - Whether remove duplication or not. - - Returns - ------- - docstr : str - Python docstring of parameter sections. - """ - param_keys = set() - param_str = [] - for i in range(num_args): - key = arg_names[i] - if key in param_keys and remove_dup: - continue - param_keys.add(key) - type_info = arg_types[i] - ret = '%s : %s' % (key, type_info) - if len(arg_descs[i]) != 0: - ret += '\n ' + arg_descs[i] - param_str.append(ret) - doc_str = ('Parameters\n' + - '----------\n' + - '%s\n') - doc_str = doc_str % ('\n'.join(param_str)) - return doc_str - - -cdef class Symbol: - # handle for symbolic operator. - cdef SymbolHandle handle - - def __dealloc__(self): - CALL(NNSymbolFree(self.handle)) - - def debug_str(self): - cdef const char* out_str - CALL(NNSymbolPrint(self.handle, &out_str)) - return str(out_str) - -cdef NewSymbol(SymbolHandle handle): - """Create a new symbol given handle""" - sym = Symbol() - sym.handle = handle - return sym - - -def Variable(const char* name, **kwargs): - """Create a symbolic variable with specified name. - - Parameters - ---------- - name : str - Name of the variable. - kwargs : dict of string -> string - Additional attributes to set on the variable. - - Returns - ------- - variable : Symbol - The created variable symbol. - """ - cdef SymbolHandle handle - CALL(NNSymbolCreateVariable(name, &handle)) - return NewSymbol(handle) - - -cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): - """Create an atomic symbol function by handle and funciton name.""" - cdef const char *name - cdef const char *desc - cdef nn_uint num_args - cdef const char** arg_names - cdef const char** arg_types - cdef const char** arg_descs - cdef const char* return_type - - CALL(NNSymbolGetAtomicSymbolInfo( - handle, &name, &desc, - &num_args, &arg_names, - &arg_types, &arg_descs, - &return_type)) - - param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) - func_name = name - doc_str = ('%s\n\n' + - '%s\n' + - 'name : string, optional.\n' + - ' Name of the resulting symbol.\n\n' + - 'Returns\n' + - '-------\n' + - 'symbol: Symbol\n' + - ' The result symbol.') - doc_str = doc_str % (desc, param_str) - - def creator(*args, **kwargs): - cdef vector[const char*] param_keys - cdef vector[const char*] param_vals - cdef vector[SymbolHandle] symbol_args - cdef vector[const char*] symbol_keys - cdef SymbolHandle ret_handle - cdef const char* c_name = NULL - - name = kwargs.pop('name', None) - attr = kwargs.pop('attr', None) - if name: - c_name = name - - if len(kwargs) != 0: - for k, v in kwargs.items(): - if isinstance(v, Symbol): - symbol_keys.push_back(k) - symbol_args.push_back((v).handle) - else: - param_keys.push_back(k) - param_vals.push_back(str(v)) - - if len(args) != 0: - if symbol_args.size() != 0: - raise TypeError("compose only accept input Symbols\ - either as positional or keyword arguments, not both") - for v in args: - if not isinstance(v, Symbol): - raise TypeError('Compose expect `Symbol` as arguments') - symbol_args.push_back((v).handle) - - CALL(NNSymbolCreateAtomicSymbol( - handle, - param_keys.size(), - CBeginPtr(param_keys), - CBeginPtr(param_vals), - &ret_handle)) - num_args = (symbol_args.size()) - CALL(NNSymbolCompose( - ret_handle, c_name, num_args, - &symbol_keys[0] if symbol_keys.size() != 0 else NULL, - &symbol_args[0] if symbol_args.size() != 0 else NULL)) - return NewSymbol(ret_handle) - - creator.__name__ = func_name - creator.__doc__ = doc_str - return creator - - -def _init_symbol_module(): - """List and add all the atomic symbol functions to current module.""" - cdef AtomicSymbolCreator* plist - cdef nn_uint size - CALL(NNSymbolListAtomicSymbolCreators(&size, &plist)) - module_obj = sys.modules[__name__] - for i in range(size): - function = _make_atomic_symbol_function(plist[i]) - if function.__name__.startswith('_'): - setattr(Symbol, function.__name__, staticmethod(function)) - else: - setattr(module_obj, function.__name__, function) - -# Initialize the atomic symbol in startups -_init_symbol_module() diff --git a/nnvm/python/setup.py b/nnvm/python/setup.py index e50d1aaccccb0..58b6e7efe088b 100644 --- a/nnvm/python/setup.py +++ b/nnvm/python/setup.py @@ -1,13 +1,29 @@ +import os +import sys from distutils.core import setup from Cython.Build import cythonize from distutils.extension import Extension + +def config(): + if sys.version_info >= (3, 0): + subdir = "_cy3" + else: + subdir = "_cy2" + ret = [] + path = "nnvm/cython" + + for fn in os.listdir(path): + if not fn.endswith(".pyx"): + continue + ret.append(Extension( + "nnvm/%s/%s" % (subdir, fn[:-4]), + ["nnvm/cython/%s" % fn], + include_dirs=["../include/"], + language="c++")) + return ret + setup( name='nnvm', - ext_modules = cythonize([ - Extension("nnvm/symbolx", - ["nnvm/symbolx.pyx"], - libraries=["nnvm"], - language="c++") - ]) + ext_modules = cythonize(config()) ) diff --git a/nnvm/tests/python/test_symbol.py b/nnvm/tests/python/test_symbol.py index 8259862152ae4..adc9099adc134 100644 --- a/nnvm/tests/python/test_symbol.py +++ b/nnvm/tests/python/test_symbol.py @@ -1,5 +1,5 @@ import nnvm.symbol as sym -from nnvm.base import NNVMError +from nnvm import NNVMError def test_compose(): x = sym.Variable('x')