diff --git a/python/mxnet/cython/ndarray.pyx b/python/mxnet/cython/ndarray.pyx index 6a066badc2e2..71b464a9da51 100644 --- a/python/mxnet/cython/ndarray.pyx +++ b/python/mxnet/cython/ndarray.pyx @@ -76,12 +76,10 @@ def _set_np_ndarray_class(cls): _np_ndarray_cls = cls -cdef NewArray(NDArrayHandle handle, int is_np_op, int stype=-1): +cdef NewArray(NDArrayHandle handle, int stype=-1, int is_np_array=0): """Create a new array given handle""" - if is_np_op: - return _np_ndarray_cls(_ctypes.cast(handle, _ctypes.c_void_p), stype=stype) - else: - return _ndarray_cls(_ctypes.cast(handle, _ctypes.c_void_p), stype=stype) + create_array_fn = _np_ndarray_cls if is_np_array else _ndarray_cls + return create_array_fn(_ctypes.cast(handle, _ctypes.c_void_p), stype=stype) cdef class CachedOp: diff --git a/python/mxnet/cython/symbol.pyx b/python/mxnet/cython/symbol.pyx index bbb609002318..1d28f987a6a3 100644 --- a/python/mxnet/cython/symbol.pyx +++ b/python/mxnet/cython/symbol.pyx @@ -96,12 +96,10 @@ def _set_np_symbol_class(cls): _np_symbol_cls = cls -cdef NewSymbol(SymbolHandle handle, int is_np_op): +cdef NewSymbol(SymbolHandle handle, int is_np_sym=0): """Create a new symbol given handle""" - if is_np_op: - sym = _np_symbol_cls(None) - else: - sym = _symbol_cls(None) + create_symbol_fn = _np_symbol_cls if is_np_sym else _symbol_cls + sym = create_symbol_fn(None) (sym).chandle = handle return sym