Skip to content

Commit

Permalink
[numpy] fix cython (apache#15418)
Browse files Browse the repository at this point in the history
* add cython support for numpy

* stay with original API for backward compatibility
  • Loading branch information
hzfan authored and haojin2 committed Jul 22, 2019
1 parent 5994ccc commit d45c7cc
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 26 deletions.
18 changes: 6 additions & 12 deletions ci/jenkins/Jenkins_steps.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ def compile_unix_cpu_openblas() {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
utils.docker_run('ubuntu_cpu', 'build_ubuntu_cpu_openblas', false)
// utils.pack_lib('cpu', mx_lib_cython, true)
utils.pack_lib('cpu', mx_lib, true)
utils.pack_lib('cpu', mx_lib_cython, true)
}
}
}
Expand Down Expand Up @@ -267,8 +266,7 @@ def compile_unix_cmake_gpu() {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
utils.docker_run('ubuntu_gpu_cu101', 'build_ubuntu_gpu_cmake', false)
// utils.pack_lib('cmake_gpu', mx_cmake_lib_cython, true)
utils.pack_lib('cmake_gpu', mx_cmake_lib, true)
utils.pack_lib('cmake_gpu', mx_cmake_lib_cython, true)
}
}
}
Expand Down Expand Up @@ -645,10 +643,8 @@ def test_unix_python2_cpu() {
node(NODE_LINUX_CPU) {
ws('workspace/ut-python2-cpu') {
try {
// utils.unpack_and_init('cpu', mx_lib_cython, true)
// python2_ut_cython('ubuntu_cpu')
utils.unpack_and_init('cpu', mx_lib, true)
python2_ut('ubuntu_cpu')
utils.unpack_and_init('cpu', mx_lib_cython, true)
python2_ut_cython('ubuntu_cpu')
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_unittest.xml', 'nosetests_python2_cpu_unittest.xml')
Expand Down Expand Up @@ -749,10 +745,8 @@ def test_unix_python3_gpu() {
node(NODE_LINUX_GPU) {
ws('workspace/ut-python3-gpu') {
try {
// utils.unpack_and_init('gpu', mx_lib_cython, true)
// python3_gpu_ut_cython('ubuntu_gpu_cu100')
utils.unpack_and_init('gpu', mx_lib, true)
python3_gpu_ut('ubuntu_gpu_cu101')
utils.unpack_and_init('gpu', mx_lib_cython, true)
python3_gpu_ut_cython('ubuntu_gpu_cu101')
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python3_gpu.xml')
Expand Down
4 changes: 2 additions & 2 deletions ci/jenkins/Jenkinsfile_unix_cpu
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ core_logic: {
custom_steps.test_unix_python3_mkldnn_mkl_cpu(),
custom_steps.test_unix_scala_cpu(),
custom_steps.test_unix_scala_mkldnn_cpu(),
// custom_steps.test_unix_clojure_cpu(),
// custom_steps.test_unix_clojure_integration_cpu(),
custom_steps.test_unix_clojure_cpu(),
custom_steps.test_unix_clojure_integration_cpu(),
custom_steps.test_unix_perl_cpu(),
custom_steps.test_unix_r_cpu(),
custom_steps.test_unix_r_mkldnn_cpu(),
Expand Down
27 changes: 19 additions & 8 deletions python/mxnet/cython/ndarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,27 @@ cdef class NDArrayBase:


_ndarray_cls = None
_np_ndarray_cls = None

def _set_ndarray_class(cls):
global _ndarray_cls
_ndarray_cls = cls


cdef NewArray(NDArrayHandle handle, int stype=-1):
def _set_np_ndarray_class(cls):
global _np_ndarray_cls
_np_ndarray_cls = cls


cdef NewArray(NDArrayHandle handle, int stype=-1, int is_np_array=0):
"""Create a new array given handle"""
return _ndarray_cls(_ctypes.cast(<unsigned long long>handle, _ctypes.c_void_p), stype=stype)
create_array_fn = _np_ndarray_cls if is_np_array else _ndarray_cls
return create_array_fn(_ctypes.cast(<unsigned long long>handle, _ctypes.c_void_p), stype=stype)


cdef class CachedOp:
"""Cached operator handle."""
cdef CachedOpHandle chandle

cdef _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
Expand All @@ -96,6 +102,8 @@ cdef class CachedOp:
def __set__(self, value):
self._set_handle(value)

cdef int is_np_sym

def __init__(self, sym, flags=()):
cdef vector[string] s_flag_keys
cdef vector[string] s_flag_vals
Expand All @@ -106,6 +114,9 @@ cdef class CachedOp:
cdef vector[const char*] c_flag_keys = SVec2Ptr(s_flag_keys)
cdef vector[const char*] c_flag_vals = SVec2Ptr(s_flag_vals)

from ..symbol.numpy._symbol import _Symbol
self.is_np_sym = bool(isinstance(sym, _Symbol))

CALL(MXCreateCachedOpEx(
<SymbolHandle>(<unsigned long long>sym.handle.value),
len(flags),
Expand Down Expand Up @@ -154,12 +165,12 @@ cdef class CachedOp:
if original_output is not None:
return original_output
if num_output == 1:
return NewArray(p_output_vars[0], p_output_stypes[0])
return NewArray(p_output_vars[0], p_output_stypes[0], self.is_np_sym)
else:
return [NewArray(p_output_vars[i], p_output_stypes[i]) for i in range(num_output)]
return [NewArray(p_output_vars[i], p_output_stypes[i], self.is_np_sym) for i in range(num_output)]


def _imperative_invoke(handle, ndargs, keys, vals, out):
def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0):
"""cython implementation of imperative invoke wrapper"""
cdef unsigned long long ihandle = handle
cdef OpHandle chandle = <OpHandle>ihandle
Expand Down Expand Up @@ -211,6 +222,6 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
if original_output is not None:
return original_output
if num_output == 1:
return NewArray(p_output_vars[0], p_output_stypes[0])
return NewArray(p_output_vars[0], p_output_stypes[0], is_np_op)
else:
return [NewArray(p_output_vars[i], p_output_stypes[i]) for i in range(num_output)]
return [NewArray(p_output_vars[i], p_output_stypes[i], is_np_op) for i in range(num_output)]
16 changes: 12 additions & 4 deletions python/mxnet/cython/symbol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,27 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):


_symbol_cls = SymbolBase
_np_symbol_cls = None

def _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls

cdef NewSymbol(SymbolHandle handle):

def _set_np_symbol_class(cls):
global _np_symbol_cls
_np_symbol_cls = cls


cdef NewSymbol(SymbolHandle handle, int is_np_sym=0):
"""Create a new symbol given handle"""
sym = _symbol_cls(None)
create_symbol_fn = _np_symbol_cls if is_np_sym else _symbol_cls
sym = create_symbol_fn(None)
(<SymbolBase>sym).chandle = handle
return sym


def _symbol_creator(handle, args, kwargs, keys, vals, name):
def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0):
cdef unsigned long long ihandle = handle
cdef OpHandle chandle = <OpHandle>ihandle
cdef vector[string] ckeys
Expand Down Expand Up @@ -143,4 +151,4 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name):
&csym_keys[0] if csym_keys.size() != 0 else NULL,
&sym_args[0] if sym_args.size() != 0 else NULL))

return NewSymbol(ret_handle)
return NewSymbol(ret_handle, is_np_op)

0 comments on commit d45c7cc

Please sign in to comment.