Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 6907838

Browse files
committed
allow mixed types in array func protocol
1 parent 8b368ac commit 6907838

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

python/mxnet/numpy/multiarray.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from ..context import current_context
4949
from ..ndarray import numpy as _mx_nd_np
5050
from ..ndarray.numpy import _internal as _npi
51-
from ..ndarray.ndarray import _storage_type, from_numpy
51+
from ..ndarray.ndarray import _storage_type
5252
from .utils import _get_np_op
5353
from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import
5454
from . import fallback
@@ -148,10 +148,12 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name
148148

149149
def _as_mx_np_array(object, ctx=None):
150150
"""Convert object to mxnet.numpy.ndarray."""
151-
if isinstance(object, _np.ndarray):
151+
if isinstance(object, ndarray):
152+
return object
153+
elif isinstance(object, _np.ndarray):
152154
if not object.flags['C_CONTIGUOUS']:
153155
object = _np.ascontiguousarray(object, dtype=object.dtype)
154-
ret = from_numpy(object)
156+
ret = array(object, ctx=ctx)
155157
return ret if ctx is None else ret.as_in_ctx(ctx=ctx)
156158
elif isinstance(object, (integer_types, numeric_types)):
157159
return object
@@ -358,11 +360,17 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
358360
out = func(*new_args, **new_kwargs)
359361
return _as_mx_np_array(out, ctx=cur_ctx)
360362
else:
361-
# Note: this allows subclasses that don't override
362-
# __array_function__ to handle mxnet.numpy.ndarray objects
363-
if not py_all(issubclass(t, ndarray) for t in types):
364-
return NotImplemented
365-
return mx_np_func(*args, **kwargs)
363+
if py_all(issubclass(t, ndarray) for t in types):
364+
return mx_np_func(*args, **kwargs)
365+
else:
366+
try:
367+
cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx'))
368+
except StopIteration:
369+
cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx'))
370+
new_args = _as_mx_np_array(args, ctx=cur_ctx)
371+
new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()}
372+
return mx_np_func(*new_args, **new_kwargs)
373+
366374

367375
def _get_np_basic_indexing(self, key):
368376
"""

tests/python/unittest/test_numpy_ndarray.py

+7
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,7 @@ def test_from_numpy(np_array, zero_copy):
13851385
mx_array = mx.npx.from_numpy(np_array, zero_copy=zero_copy)
13861386
mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy())
13871387

1388+
@use_np
13881389
def test_from_numpy_exception():
13891390
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
13901391
mx_array = mx.npx.from_numpy(np_array)
@@ -1397,3 +1398,9 @@ def test_from_numpy_exception():
13971398
assert not np_array.flags["C_CONTIGUOUS"]
13981399
with pytest.raises(ValueError):
13991400
mx_array = mx.nd.from_numpy(np_array)
1401+
1402+
@use_np
1403+
def test_mixed_array_types():
1404+
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
1405+
mx_array = mx.np.ones((3, 1))
1406+
assert_almost_equal(mx_array + np_array, 1+np_array)

0 commit comments

Comments
 (0)