From c06b2874484390f04ed79c2a12c24cdd219bfe4c Mon Sep 17 00:00:00 2001 From: Fan Date: Tue, 2 Jul 2019 10:20:35 +0800 Subject: [PATCH] stay with original API for backward compatibility --- python/mxnet/cython/ndarray.pyx | 18 ++++++++---------- python/mxnet/cython/symbol.pyx | 10 ++++------ 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/python/mxnet/cython/ndarray.pyx b/python/mxnet/cython/ndarray.pyx index 6a066badc2e2..50791e9b9a86 100644 --- a/python/mxnet/cython/ndarray.pyx +++ b/python/mxnet/cython/ndarray.pyx @@ -76,12 +76,10 @@ def _set_np_ndarray_class(cls): _np_ndarray_cls = cls -cdef NewArray(NDArrayHandle handle, int is_np_op, int stype=-1): +cdef NewArray(NDArrayHandle handle, int stype=-1, int is_np_array=0): """Create a new array given handle""" - if is_np_op: - return _np_ndarray_cls(_ctypes.cast(handle, _ctypes.c_void_p), stype=stype) - else: - return _ndarray_cls(_ctypes.cast(handle, _ctypes.c_void_p), stype=stype) + create_array_fn = _np_ndarray_cls if is_np_array else _ndarray_cls + return create_array_fn(_ctypes.cast(handle, _ctypes.c_void_p), stype=stype) cdef class CachedOp: @@ -167,12 +165,12 @@ cdef class CachedOp: if original_output is not None: return original_output if num_output == 1: - return NewArray(p_output_vars[0], self.is_np_sym, p_output_stypes[0]) + return NewArray(p_output_vars[0], p_output_stypes[0], self.is_np_sym) else: - return [NewArray(p_output_vars[i], self.is_np_sym, p_output_stypes[i]) for i in range(num_output)] + return [NewArray(p_output_vars[i], p_output_stypes[i], self.is_np_sym) for i in range(num_output)] -def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op): +def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0): """cython implementation of imperative invoke wrapper""" cdef unsigned long long ihandle = handle cdef OpHandle chandle = ihandle @@ -224,6 +222,6 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op): if original_output is not None: return original_output if num_output == 1: - return NewArray(p_output_vars[0], is_np_op, p_output_stypes[0]) + return NewArray(p_output_vars[0], p_output_stypes[0], is_np_op) else: - return [NewArray(p_output_vars[i], is_np_op, p_output_stypes[i]) for i in range(num_output)] + return [NewArray(p_output_vars[i], p_output_stypes[i], is_np_op) for i in range(num_output)] diff --git a/python/mxnet/cython/symbol.pyx b/python/mxnet/cython/symbol.pyx index bbb609002318..86fe8ae6db4f 100644 --- a/python/mxnet/cython/symbol.pyx +++ b/python/mxnet/cython/symbol.pyx @@ -96,17 +96,15 @@ def _set_np_symbol_class(cls): _np_symbol_cls = cls -cdef NewSymbol(SymbolHandle handle, int is_np_op): +cdef NewSymbol(SymbolHandle handle, int is_np_sym=0): """Create a new symbol given handle""" - if is_np_op: - sym = _np_symbol_cls(None) - else: - sym = _symbol_cls(None) + create_symbol_fn = _np_symbol_cls if is_np_sym else _symbol_cls + sym = create_symbol_fn(None) (sym).chandle = handle return sym -def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op): +def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0): cdef unsigned long long ihandle = handle cdef OpHandle chandle = ihandle cdef vector[string] ckeys