Skip to content

Commit c7d6173

Browse files
authored
Merge pull request #282 from bashtage/edge-cases
MAINT: Port output check from NumPy
2 parents fa533c5 + 58757ab commit c7d6173

7 files changed

+101
-34
lines changed

doc/source/change-log.rst

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ Change Log
1616
cannot update NumPy.
1717

1818

19-
Since v1.20.0
20-
=============
19+
v1.20.1
20+
=======
21+
- Fixed a bug that affects :func:`~randomgen.generator.Generator.standard_gamma` when
22+
used with ``out`` and a Fortran contiguous array.
2123
- Added :func:`~randomgen.generator.ExtendedGenerator.multivariate_complex_normal`.
2224
- Added :func:`~randomgen.generator.ExtendedGenerator.standard_wishart` and
2325
:func:`~randomgen.generator.ExtendedGenerator.wishart` variate generators.

doc/source/names_wordlist.txt

+8-1
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,11 @@ Wooldridge
115115
Zhenyu
116116
ecuyer
117117
Overton
118-
Horner
118+
Horner
119+
Feiveson
120+
Jour
121+
Uhlig
122+
Dıaz
123+
Garcıa
124+
Jáimez
125+
Mardia

doc/source/spelling_wordlist.txt

+5
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,8 @@ Intrinsics
279279
precomputed
280280
args
281281
kwargs
282+
Wishart
283+
wishart
284+
Fortran
285+
loc
286+
trivariate

randomgen/common.pyx

+33-12
Original file line numberDiff line numberDiff line change
@@ -533,23 +533,44 @@ cdef validate_output_shape(iter_shape, np.ndarray output):
533533
)
534534

535535

536-
cdef check_output(object out, object dtype, object size):
536+
cdef check_output(object out, object dtype, object size, bint require_c_array):
537+
"""
538+
Check user-supplied output array properties and shape
539+
540+
Parameters
541+
----------
542+
out : {ndarray, None}
543+
The array to check. If None, returns immediately.
544+
dtype : dtype
545+
The required dtype of out.
546+
size : {None, int, tuple[int]}
547+
The size passed. If out is an ndarray, verifies that the shape of out
548+
matches size.
549+
require_c_array : bool
550+
Whether out must be a C-array. If False, out can be either C- or F-
551+
ordered. If True, must be C-ordered. In either case, must be
552+
contiguous, writable, aligned and in native byte-order.
553+
"""
537554
if out is None:
538555
return
539556
cdef np.ndarray out_array = <np.ndarray>out
540-
if not (np.PyArray_CHKFLAGS(out_array, api.NPY_ARRAY_CARRAY) or
541-
np.PyArray_CHKFLAGS(out_array, api.NPY_ARRAY_FARRAY)):
542-
raise ValueError("Supplied output array is not contiguous, writable or aligned.")
557+
if not (np.PyArray_ISCARRAY(out_array) or
558+
(np.PyArray_ISFARRAY(out_array) and not require_c_array)):
559+
req = "C-" if require_c_array else ""
560+
raise ValueError(
561+
f'Supplied output array must be {req}contiguous, writable, '
562+
f'aligned, and in machine byte-order.'
563+
)
543564
if out_array.dtype != dtype:
544-
raise TypeError("Supplied output array has the wrong type. "
545-
"Expected {0}, got {0}".format(dtype, out_array.dtype))
565+
raise TypeError('Supplied output array has the wrong type. '
566+
'Expected {0}, got {1}'.format(np.dtype(dtype), out_array.dtype))
546567
if size is not None:
547568
try:
548569
tup_size = tuple(size)
549570
except TypeError:
550571
tup_size = tuple([size])
551572
if tup_size != out.shape:
552-
raise ValueError("size must match out.shape when used together")
573+
raise ValueError('size must match out.shape when used together')
553574

554575

555576
cdef object double_fill(void *func, bitgen_t *state, object size, object lock, object out):
@@ -565,7 +586,7 @@ cdef object double_fill(void *func, bitgen_t *state, object size, object lock, o
565586
return out_val
566587

567588
if out is not None:
568-
check_output(out, np.float64, size)
589+
check_output(out, np.float64, size, False)
569590
out_array = <np.ndarray>out
570591
else:
571592
out_array = <np.ndarray>np.empty(size, np.double)
@@ -587,7 +608,7 @@ cdef object float_fill(void *func, bitgen_t *state, object size, object lock, ob
587608
return random_func(state)
588609

589610
if out is not None:
590-
check_output(out, np.float32, size)
611+
check_output(out, np.float32, size, False)
591612
out_array = <np.ndarray>out
592613
else:
593614
out_array = <np.ndarray>np.empty(size, np.float32)
@@ -610,7 +631,7 @@ cdef object float_fill_from_double(void *func, bitgen_t *state, object size, obj
610631
return <float>random_func(state)
611632

612633
if out is not None:
613-
check_output(out, np.float32, size)
634+
check_output(out, np.float32, size, False)
614635
out_array = <np.ndarray>out
615636
else:
616637
out_array = <np.ndarray>np.empty(size, np.float32)
@@ -826,7 +847,7 @@ cdef object cont(void *func, void *state, object size, object lock, int narg,
826847
cdef np.ndarray a_arr, b_arr, c_arr
827848
cdef double _a = 0.0, _b = 0.0, _c = 0.0
828849
cdef bint is_scalar = True
829-
check_output(out, np.float64, size)
850+
check_output(out, np.float64, size, narg > 0)
830851
if narg > 0:
831852
a_arr = <np.ndarray>np.PyArray_FROM_OTF(a, np.NPY_DOUBLE, api.NPY_ARRAY_ALIGNED)
832853
is_scalar = is_scalar and np.PyArray_NDIM(a_arr) == 0
@@ -1276,7 +1297,7 @@ cdef object cont_f(void *func, bitgen_t *state, object size, object lock,
12761297
cdef float _a
12771298
cdef bint is_scalar = True
12781299
cdef int requirements = api.NPY_ARRAY_ALIGNED | api.NPY_ARRAY_FORCECAST
1279-
check_output(out, np.float32, size)
1300+
check_output(out, np.float32, size, True)
12801301
a_arr = <np.ndarray>np.PyArray_FROMANY(a, np.NPY_FLOAT32, 0, 0, requirements)
12811302
is_scalar = np.PyArray_NDIM(a_arr) == 0
12821303

randomgen/generator.pyx

+27-15
Original file line numberDiff line numberDiff line change
@@ -5355,7 +5355,7 @@ cdef class ExtendedGenerator:
53555355
53565356
Notes
53575357
-----
5358-
Uses the method of Odell and Fieveson [1]_ when `df` >= `dim`.
5358+
Uses the method of Odell and Feiveson [1]_ when `df` >= `dim`.
53595359
Otherwise variates are directly generated as the inner product
53605360
of `df` by `dim` arrays of standard normal random variates.
53615361
@@ -5405,12 +5405,12 @@ cdef class ExtendedGenerator:
54055405
"""
54065406
wishart(df, scale, size=None, *, check_valid="warn", tol=None, rank=None, method="svd")
54075407
5408-
Draw samples from the Wishart and psuedo-Wishart distributions.
5408+
Draw samples from the Wishart and pseudo-Wishart distributions.
54095409
54105410
Parameters
54115411
----------
54125412
df : {int, array_like[int]}
5413-
Degree-of-freedom values. In array-like must boradcast with all
5413+
Degree-of-freedom values. In array-like must broadcast with all
54145414
but the final two dimensions of ``shape``.
54155415
scale : array_like
54165416
Shape matrix of the distribution. It must be symmetric and
@@ -5459,7 +5459,7 @@ cdef class ExtendedGenerator:
54595459
54605460
Notes
54615461
-----
5462-
Uses the method of Odell and Fieveson [1]_ when `df` >= `dim`.
5462+
Uses the method of Odell and Feiveson [1]_ when `df` >= `dim`.
54635463
Otherwise variates are directly generated as the inner product
54645464
of `df` by `dim` arrays of standard normal random variates.
54655465
@@ -5491,10 +5491,10 @@ cdef class ExtendedGenerator:
54915491
shape_arr = <np.ndarray>np.asarray(scale, dtype=np.float64, order="C")
54925492
shape_nd = np.PyArray_NDIM(shape_arr)
54935493
msg = (
5494-
"scale must have at least 2 dimensions. The final two dimensions "
5495-
"must be the same so that scale's shape is (...,N,N)."
5494+
"scale must be non-empty and have at least 2 dimensions. The final "
5495+
"two dimensions must be the same so that scale's shape is (...,N,N)."
54965496
)
5497-
if shape_nd < 2:
5497+
if shape_nd < 2 or shape_arr.size == 0:
54985498
raise ValueError(msg)
54995499
dim = np.shape(shape_arr)[shape_nd-1]
55005500
rank_val = dim if rank is None else int(rank)
@@ -5720,8 +5720,8 @@ and the trailing dimensions must match exactly so that
57205720
np.NPY_ARRAY_ALIGNED |
57215721
np.NPY_ARRAY_C_CONTIGUOUS)
57225722
ldim = np.PyArray_NDIM(larr)
5723-
if ldim < 1:
5724-
raise ValueError("loc must be at least 1-dimensional")
5723+
if ldim < 1 or larr.size == 0:
5724+
raise ValueError("loc must be non-empty and at least 1-dimensional")
57255725
dim = np.PyArray_DIMS(larr)[ldim - 1]
57265726

57275727
if gamma is None:
@@ -5734,10 +5734,16 @@ and the trailing dimensions must match exactly so that
57345734

57355735
gdim = np.PyArray_NDIM(garr)
57365736
gshape = np.PyArray_DIMS(garr)
5737-
if gdim < 2 or gshape[gdim - 2] != gshape[gdim - 1] or gshape[gdim - 1] != dim:
5737+
if (
5738+
gdim < 2 or
5739+
gshape[gdim - 2] != gshape[gdim - 1] or
5740+
gshape[gdim - 1] != dim or
5741+
garr.size == 0
5742+
):
57385743
raise ValueError(
5739-
"gamma must be at least 2-dimensional and the final two dimensions "
5740-
f"must match the final dimension of loc, {dim}."
5744+
"gamma must be non-empty with at least 2-dimensional and the "
5745+
"final two dimensions must match the final dimension of loc,"
5746+
f" {dim}."
57415747
)
57425748
if relation is None:
57435749
rarr = <np.ndarray>np.zeros((dim,dim), dtype=complex)
@@ -5748,10 +5754,16 @@ and the trailing dimensions must match exactly so that
57485754
np.NPY_ARRAY_C_CONTIGUOUS)
57495755
rdim = np.PyArray_NDIM(rarr)
57505756
rshape = np.PyArray_DIMS(rarr)
5751-
if rdim < 2 or rshape[rdim - 2] != rshape[rdim - 1] or rshape[rdim - 1] != dim:
5757+
if (
5758+
rdim < 2 or
5759+
rshape[rdim - 2] != rshape[rdim - 1] or
5760+
rshape[rdim - 1] != dim or
5761+
rarr.size == 0
5762+
):
57525763
raise ValueError(
5753-
"relation must be at least 2-dimensional and the final two dimensions "
5754-
f"must match the final dimension of loc, {dim}."
5764+
"relation must be non-empty with at least 2-dimensional and the "
5765+
"final two dimensions must match the final dimension of loc,"
5766+
f" {dim}."
57555767
)
57565768
can_bcast, cov_shape = broadcast_shape(np.shape(garr), np.shape(rarr), False)
57575769
if not can_bcast:

randomgen/rdrand.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ cdef class RDRAND(BitGenerator):
178178
state structure, or use PyErr_Occurred to see if an error occurred
179179
during generation.
180180
181-
To see the exception you will generatr, you can run this invalid code
181+
To see the exception you will generate, you can run this invalid code
182182
183183
>>> from numpy.random import Generator
184184
>>> from randomgen import RDRAND

randomgen/tests/test_extended_generator.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -436,11 +436,11 @@ def test_missing_scipy_exception():
436436

437437
def test_wishart_exceptions():
438438
eg = ExtendedGenerator()
439-
with pytest.raises(ValueError, match="scale must have at"):
439+
with pytest.raises(ValueError, match="scale must be non-empty"):
440440
eg.wishart(10, [10])
441-
with pytest.raises(ValueError, match="scale must have at"):
441+
with pytest.raises(ValueError, match="scale must be non-empty"):
442442
eg.wishart(10, 10)
443-
with pytest.raises(ValueError, match="scale must have at"):
443+
with pytest.raises(ValueError, match="scale must be non-empty"):
444444
eg.wishart(10, np.array([[1, 2]]))
445445
with pytest.raises(ValueError, match="At least one"):
446446
eg.wishart([], np.eye(2))
@@ -596,3 +596,23 @@ def test_mv_complex_normal_exceptions(extended_gen):
596596
extended_gen.multivariate_complex_normal(
597597
[0.0, 0.0], np.ones((4, 1, 3, 2, 2)), np.ones((1, 1, 2, 2, 2))
598598
)
599+
600+
601+
def test_wishart_edge(extended_gen):
602+
with pytest.raises(ValueError, match="scale must be non-empty"):
603+
extended_gen.wishart(5, np.empty((0, 0)))
604+
with pytest.raises(ValueError, match="scale must be non-empty"):
605+
extended_gen.wishart(5, np.empty((0, 2, 2)))
606+
with pytest.raises(ValueError, match="scale must be non-empty"):
607+
extended_gen.wishart(5, [[]])
608+
with pytest.raises(ValueError, match="At least one value is required"):
609+
extended_gen.wishart(np.empty((0, 2, 3)), np.eye(2))
610+
611+
612+
def test_mv_complex_normal_edge(extended_gen):
613+
with pytest.raises(ValueError, match="loc must be non-empty and at least"):
614+
extended_gen.multivariate_complex_normal(np.empty((0, 2)))
615+
with pytest.raises(ValueError, match="gamma must be non-empty"):
616+
extended_gen.multivariate_complex_normal([0, 0], np.empty((0, 2, 2)))
617+
with pytest.raises(ValueError, match="relation must be non-empty"):
618+
extended_gen.multivariate_complex_normal([0, 0], np.eye(2), np.empty((0, 2, 2)))

0 commit comments

Comments
 (0)