Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address the comment
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Sep 3, 2019
1 parent 675b815 commit d0b858b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 24 deletions.
4 changes: 3 additions & 1 deletion python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

"""Namespace for operators used in Gluon dispatched by F=ndarray."""
from __future__ import absolute_import
import numpy as np
from ...context import current_context
from . import _internal as _npi
from ..ndarray import NDArray
from ...base import numeric_types


Expand Down Expand Up @@ -237,7 +239,7 @@ def multinomial(n, pvals, size=None):
return _npi.multinomial(pvals, pvals=None, n=n, size=size)
else:
if isinstance(pvals, np.ndarray):
pvals = pvals.tolist()
raise ValueError('numpy ndarray is not supported!')
if any(isinstance(i, list) for i in pvals):
raise ValueError('object too deep for desired array')
return _npi.multinomial(n=n, pvals=pvals, size=size)
41 changes: 18 additions & 23 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,29 +774,24 @@ def test_np_multinomial():
pvals_list = [[0.0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.0]]
sizes = [None, (), (3,), (2, 5, 7), (4, 9)]
experiements = 10000
for pvals_type in [list, _np.ndarray]:
for have_size in [False, True]:
for pvals in pvals_list:
if have_size:
for size in sizes:
if pvals_type == mx.nd.NDArray:
pvals = mx.nd.array(pvals).as_np_ndarray()
elif pvals_type == _np.ndarray:
pvals = _np.array(pvals)
freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy() / _np.float32(experiements)
# for those cases that didn't need reshape
if size in [None, ()]:
mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
else:
# check the shape
assert freq.shape == size + (len(pvals),), 'freq.shape={}, size + (len(pvals))={}'.format(freq.shape, size + (len(pvals)))
freq = freq.reshape((-1, len(pvals)))
# check the value for each row
for i in range(freq.shape[0]):
mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
else:
freq = mx.np.random.multinomial(experiements, pvals).asnumpy() / _np.float32(experiements)
mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
for have_size in [False, True]:
for pvals in pvals_list:
if have_size:
for size in sizes:
freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy() / _np.float32(experiements)
# for those cases that didn't need reshape
if size in [None, ()]:
mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
else:
# check the shape
assert freq.shape == size + (len(pvals),), 'freq.shape={}, size + (len(pvals))={}'.format(freq.shape, size + (len(pvals)))
freq = freq.reshape((-1, len(pvals)))
# check the value for each row
for i in range(freq.shape[0]):
mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
else:
freq = mx.np.random.multinomial(experiements, pvals).asnumpy() / _np.float32(experiements)
mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
# check the zero dimension
sizes = [(0), (0, 2), (4, 0, 2), (3, 0, 1, 2, 0)]
for pvals in pvals_list:
Expand Down

0 comments on commit d0b858b

Please sign in to comment.