This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] where does not support python scalar as the input #17179

Closed
sxjscience opened this issue Dec 26, 2019 · 8 comments

@sxjscience (Member)

import mxnet as mx
mx.npx.set_np()
a = mx.sym.var('a').as_np_ndarray()  # numpy-style symbolic variable
mx.sym.np.where(a, a, 0)             # fails: y is the Python scalar 0

Error message:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-5-e7a6fe89b743> in <module>
----> 1 mx.sym.np.where(a, a, 0)

~/mxnet/python/mxnet/symbol/numpy/_symbol.py in where(condition, x, y)
   5501 
   5502     """
-> 5503     return _npi.where(condition, x, y, out=None)
   5504 
   5505 

~/mxnet/python/mxnet/symbol/register.py in where(condition, x, y, name, attr, out, **kwargs)

AssertionError: Argument y must be Symbol instances, but got 0

The imperative API fails in the same way:

import mxnet as mx
mx.npx.set_np()
mx.np.where(mx.np.ones((10, )), mx.np.ones((10, )), 0)

Error message:

AssertionError                            Traceback (most recent call last)
<ipython-input-14-bdddf3065582> in <module>
----> 1 mx.np.where(mx.np.ones((10, )), mx.np.ones((10, )), 0)

~/mxnet/python/mxnet/numpy/multiarray.py in where(condition, x, y)
   7996            [ 0.,  3., -1.]])
   7997     """
-> 7998     return _mx_nd_np.where(condition, x, y)
   7999 
   8000 

~/mxnet/python/mxnet/ndarray/numpy/_op.py in where(condition, x, y)
   6035         return nonzero(condition)
   6036     else:
-> 6037         return _npi.where(condition, x, y, out=None)
   6038 
   6039 

~/mxnet/python/mxnet/ndarray/register.py in where(condition, x, y, out, name, **kwargs)

AssertionError: Argument y must have NDArray type, but got 0
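
For reference, the behavior expected here is the usual NumPy one: a Python scalar passed as x or y is broadcast against condition. The pure-Python sketch below illustrates that on flat lists only. where_with_scalars and _expand are hypothetical helpers, not MXNet functions, and multi-dimensional broadcasting is deliberately left out.

```python
from numbers import Number
from typing import List, Sequence, Union

# Hypothetical operand type: either a Python scalar or a flat sequence.
Operand = Union[Number, Sequence[Number]]


def _expand(value: Operand, size: int) -> List[Number]:
    """Broadcast a scalar to `size` elements; pass sequences through after a length check."""
    if isinstance(value, Number):
        return [value] * size
    if len(value) != size:
        raise ValueError(f"shape mismatch: expected {size} elements, got {len(value)}")
    return list(value)


def where_with_scalars(condition: Sequence[bool], x: Operand, y: Operand) -> List[Number]:
    """Elementwise select: x[i] where condition[i] is truthy, else y[i]."""
    n = len(condition)
    xs, ys = _expand(x, n), _expand(y, n)
    return [xi if c else yi for c, xi, yi in zip(condition, xs, ys)]


# Mirrors mx.np.where(mx.np.ones((10,)), mx.np.ones((10,)), 0) from the report.
print(where_with_scalars([1.0] * 10, [1.0] * 10, 0))
```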

@xidulu (Contributor)

xidulu commented Dec 27, 2019

I actually raised this problem in a comment on PR #16829.
The author suggested wrapping the scalar with np.array.

@samskalicky (Contributor)

@apeforest assign [@reminisce ]

@sxjscience (Member, issue author)

sxjscience commented Dec 31, 2019

@xidulu @hgt312 In my use case, I'm using the symbolic interface, so I cannot call mx.np.array.

@reminisce (Contributor)

@hgt312

@hgt312 (Contributor)

hgt312 commented Dec 31, 2019

I will add the scalar version soon.
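
A "scalar version" of an operator typically means routing scalar arguments to a dedicated operator variant instead of failing the Symbol/NDArray type assertion. The sketch below shows that dispatch pattern in plain Python. The Tensor class and the _where_* functions are hypothetical stand-ins for MXNet's internal operators; this is not the code that eventually landed in MXNet.

```python
from numbers import Number


class Tensor:
    """Hypothetical stand-in for an NDArray/Symbol: just wraps a flat list."""

    def __init__(self, data):
        self.data = list(data)


def _where_tensor(cond, x, y):
    # Both branches are tensors: plain elementwise select.
    return Tensor(a if c else b for c, a, b in zip(cond.data, x.data, y.data))


def _where_lscalar(cond, scalar, y):
    # x is a Python scalar, y is a tensor.
    return Tensor(scalar if c else b for c, b in zip(cond.data, y.data))


def _where_rscalar(cond, x, scalar):
    # x is a tensor, y is a Python scalar (the case in this report).
    return Tensor(a if c else scalar for c, a in zip(cond.data, x.data))


def _where_scalar2(cond, x_scalar, y_scalar):
    # Both branches are Python scalars; only cond supplies the shape.
    return Tensor(x_scalar if c else y_scalar for c in cond.data)


def where(cond, x, y):
    """Pick the operator variant from the argument types instead of asserting."""
    x_is_scalar, y_is_scalar = isinstance(x, Number), isinstance(y, Number)
    if x_is_scalar and y_is_scalar:
        return _where_scalar2(cond, x, y)
    if x_is_scalar:
        return _where_lscalar(cond, x, y)
    if y_is_scalar:
        return _where_rscalar(cond, x, y)
    return _where_tensor(cond, x, y)


print(where(Tensor([1, 0, 1]), Tensor([4, 5, 6]), 0).data)  # prints [4, 0, 6]
```

In the both-scalar branch there is no array operand to inherit a dtype from, which is the open question in the next comment.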

@hgt312 (Contributor)

hgt312 commented Jan 6, 2020

If both x and y are scalars, what should the dtype of the output be? In official NumPy, it can be int64 or float64 depending on the input types.
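
One NumPy-compatible answer is to reproduce the result types NumPy itself produces when both branches are Python scalars. The helper below is a hypothetical sketch that ranks the scalars by kind (bool < int < float < complex), takes the higher kind, and maps it to NumPy's default 64-bit type. It assumes a platform where NumPy's default integer is int64 (e.g. 64-bit Linux) and does not model mixed array/scalar promotion.

```python
def where_scalar_dtype(x, y):
    """Result dtype name for where(cond, x, y) when x and y are both Python scalars."""
    # Ordered from lowest to highest kind; the result takes the higher kind.
    kinds = [(bool, "bool"), (int, "int64"), (float, "float64"), (complex, "complex128")]

    def rank(value):
        # bool is checked first because bool is a subclass of int in Python.
        for i, (py_type, _) in enumerate(kinds):
            if isinstance(value, py_type):
                return i
        raise TypeError(f"unsupported scalar type: {type(value).__name__}")

    return kinds[max(rank(x), rank(y))][1]
```

For example, where_scalar_dtype(1, 0) gives "int64" and where_scalar_dtype(1.0, 0) gives "float64", matching the dtypes of np.where(cond, 1, 0) and np.where(cond, 1.0, 0) on a typical 64-bit Linux NumPy build.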

@sxjscience (Member, issue author)

Can we make it NumPy-compatible?

@yzhliu (Member)

yzhliu commented Apr 29, 2020

Closed by #17249

@yzhliu yzhliu closed this as completed Apr 29, 2020