Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
stay with original API for backward compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Fan committed Jul 2, 2019
1 parent 2a571a1 commit c06b287
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
18 changes: 8 additions & 10 deletions python/mxnet/cython/ndarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,10 @@ def _set_np_ndarray_class(cls):
_np_ndarray_cls = cls


cdef NewArray(NDArrayHandle handle, int is_np_op, int stype=-1):
cdef NewArray(NDArrayHandle handle, int stype=-1, int is_np_array=0):
"""Create a new array given handle"""
if is_np_op:
return _np_ndarray_cls(_ctypes.cast(<unsigned long long>handle, _ctypes.c_void_p), stype=stype)
else:
return _ndarray_cls(_ctypes.cast(<unsigned long long>handle, _ctypes.c_void_p), stype=stype)
create_array_fn = _np_ndarray_cls if is_np_array else _ndarray_cls
return create_array_fn(_ctypes.cast(<unsigned long long>handle, _ctypes.c_void_p), stype=stype)


cdef class CachedOp:
Expand Down Expand Up @@ -167,12 +165,12 @@ cdef class CachedOp:
if original_output is not None:
return original_output
if num_output == 1:
return NewArray(p_output_vars[0], self.is_np_sym, p_output_stypes[0])
return NewArray(p_output_vars[0], p_output_stypes[0], self.is_np_sym)
else:
return [NewArray(p_output_vars[i], self.is_np_sym, p_output_stypes[i]) for i in range(num_output)]
return [NewArray(p_output_vars[i], p_output_stypes[i], self.is_np_sym) for i in range(num_output)]


def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0):
"""cython implementation of imperative invoke wrapper"""
cdef unsigned long long ihandle = handle
cdef OpHandle chandle = <OpHandle>ihandle
Expand Down Expand Up @@ -224,6 +222,6 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
if original_output is not None:
return original_output
if num_output == 1:
return NewArray(p_output_vars[0], is_np_op, p_output_stypes[0])
return NewArray(p_output_vars[0], p_output_stypes[0], is_np_op)
else:
return [NewArray(p_output_vars[i], is_np_op, p_output_stypes[i]) for i in range(num_output)]
return [NewArray(p_output_vars[i], p_output_stypes[i], is_np_op) for i in range(num_output)]
10 changes: 4 additions & 6 deletions python/mxnet/cython/symbol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,15 @@ def _set_np_symbol_class(cls):
_np_symbol_cls = cls


cdef NewSymbol(SymbolHandle handle, int is_np_op):
cdef NewSymbol(SymbolHandle handle, int is_np_sym=0):
"""Create a new symbol given handle"""
if is_np_op:
sym = _np_symbol_cls(None)
else:
sym = _symbol_cls(None)
create_symbol_fn = _np_symbol_cls if is_np_sym else _symbol_cls
sym = create_symbol_fn(None)
(<SymbolBase>sym).chandle = handle
return sym


def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op):
def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0):
cdef unsigned long long ihandle = handle
cdef OpHandle chandle = <OpHandle>ihandle
cdef vector[string] ckeys
Expand Down

0 comments on commit c06b287

Please sign in to comment.