|
48 | 48 | from ..context import current_context
|
49 | 49 | from ..ndarray import numpy as _mx_nd_np
|
50 | 50 | from ..ndarray.numpy import _internal as _npi
|
51 |
| -from ..ndarray.ndarray import _storage_type, from_numpy |
| 51 | +from ..ndarray.ndarray import _storage_type |
52 | 52 | from .utils import _get_np_op
|
53 | 53 | from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import
|
54 | 54 | from . import fallback
|
@@ -148,10 +148,10 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name
|
148 | 148 |
|
149 | 149 | def _as_mx_np_array(object, ctx=None):
|
150 | 150 | """Convert object to mxnet.numpy.ndarray."""
|
151 |
| - if isinstance(object, _np.ndarray): |
152 |
| - if not object.flags['C_CONTIGUOUS']: |
153 |
| - object = _np.ascontiguousarray(object, dtype=object.dtype) |
154 |
| - ret = from_numpy(object, array_cls=ndarray) |
| 151 | + if isinstance(object, ndarray): |
| 152 | + return object |
| 153 | + elif isinstance(object, _np.ndarray): |
| 154 | + ret = array(object, dtype=object.dtype, ctx=ctx) |
155 | 155 | return ret if ctx is None else ret.as_in_ctx(ctx=ctx)
|
156 | 156 | elif isinstance(object, (integer_types, numeric_types)):
|
157 | 157 | return object
|
@@ -358,11 +358,17 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
|
358 | 358 | out = func(*new_args, **new_kwargs)
|
359 | 359 | return _as_mx_np_array(out, ctx=cur_ctx)
|
360 | 360 | else:
|
361 |
| - # Note: this allows subclasses that don't override |
362 |
| - # __array_function__ to handle mxnet.numpy.ndarray objects |
363 |
| - if not py_all(issubclass(t, ndarray) for t in types): |
364 |
| - return NotImplemented |
365 |
| - return mx_np_func(*args, **kwargs) |
| 361 | + if py_all(issubclass(t, ndarray) for t in types): |
| 362 | + return mx_np_func(*args, **kwargs) |
| 363 | + else: |
| 364 | + try: |
| 365 | + cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx')) |
| 366 | + except StopIteration: |
| 367 | + cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx')) |
| 368 | + new_args = _as_mx_np_array(args, ctx=cur_ctx) |
| 369 | + new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()} |
| 370 | + return mx_np_func(*new_args, **new_kwargs) |
| 371 | + |
366 | 372 |
|
367 | 373 | def _get_np_basic_indexing(self, key):
|
368 | 374 | """
|
|
0 commit comments