From d79d7e665a141a37bf78225e989b4b5acedc2c8d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 6 May 2024 21:51:47 +0200 Subject: [PATCH] Extract and use strategy for finite floats --- tests/hypothesis_strategies/basic.py | 8 ++++++++ tests/hypothesis_strategies/priors.py | 19 ++++++++++--------- tests/hypothesis_strategies/utils.py | 11 +++++------ 3 files changed, 23 insertions(+), 15 deletions(-) create mode 100644 tests/hypothesis_strategies/basic.py diff --git a/tests/hypothesis_strategies/basic.py b/tests/hypothesis_strategies/basic.py new file mode 100644 index 0000000000..3df4bbfd2d --- /dev/null +++ b/tests/hypothesis_strategies/basic.py @@ -0,0 +1,8 @@ +"""Strategies for basic types.""" + +from functools import partial + +import hypothesis.strategies as st + +finite_floats = partial(st.floats, allow_infinity=False, allow_nan=False) +"""A strategy producing finite (i.e., non-nan and non-infinite) floats.""" diff --git a/tests/hypothesis_strategies/priors.py b/tests/hypothesis_strategies/priors.py index 7eeee29cde..c438f94046 100644 --- a/tests/hypothesis_strategies/priors.py +++ b/tests/hypothesis_strategies/priors.py @@ -11,38 +11,39 @@ SmoothedBoxPrior, ) +from .basic import finite_floats from .utils import intervals gamma_priors = st.builds( GammaPrior, - st.floats(min_value=0, exclude_min=True), - st.floats(min_value=0, exclude_min=True), + finite_floats(min_value=0.0, exclude_min=True), + finite_floats(min_value=0.0, exclude_min=True), ) """A strategy that generates Gamma priors.""" half_cauchy_priors = st.builds( HalfCauchyPrior, - st.floats(min_value=0, exclude_min=True), + finite_floats(min_value=0.0, exclude_min=True), ) """A strategy that generates Half-Cauchy priors.""" normal_priors = st.builds( NormalPrior, - st.floats(allow_nan=False, allow_infinity=False), - st.floats(min_value=0, exclude_min=True), + finite_floats(), + finite_floats(min_value=0.0, exclude_min=True), ) """A strategy that generates Normal priors.""" half_normal_priors = st.builds( HalfNormalPrior, - st.floats(min_value=0, exclude_min=True), + finite_floats(min_value=0.0, exclude_min=True), ) """A strategy that generates Half-Normal priors.""" log_normal_priors = st.builds( LogNormalPrior, - st.floats(allow_nan=False, allow_infinity=False), - st.floats(min_value=0, exclude_min=True), + finite_floats(), + finite_floats(min_value=0.0, exclude_min=True), ) """A strategy that generates Log-Normal priors.""" @@ -52,7 +53,7 @@ def _smoothed_box_priors(draw: st.DrawFn): """A strategy that generates Smoothed-Box priors.""" interval = draw(intervals(exclude_half_bounded=True, exclude_fully_unbounded=True)) sigma = draw( - st.floats(min_value=0, exclude_min=True), + finite_floats(min_value=0.0, exclude_min=True), ) return SmoothedBoxPrior(*interval.to_tuple(), sigma) diff --git a/tests/hypothesis_strategies/utils.py b/tests/hypothesis_strategies/utils.py index 53e8905db8..fb57555dae 100644 --- a/tests/hypothesis_strategies/utils.py +++ b/tests/hypothesis_strategies/utils.py @@ -7,6 +7,8 @@ from baybe.utils.interval import Interval +from .basic import finite_floats + class IntervalType(Enum): """The possible types of an interval on the real number line.""" @@ -38,9 +40,6 @@ def intervals( allowed_types = [t for t, b in type_gate.items() if b] interval_type = draw(st.sampled_from(allowed_types)) - # A strategy producing finite floats - ffloats = st.floats(allow_infinity=False, allow_nan=False) - # Draw the bounds depending on the interval type if interval_type is IntervalType.FULLY_UNBOUNDED: bounds = (None, None) @@ -48,8 +47,8 @@ def intervals( bounds = draw( st.sampled_from( [ - (None, draw(ffloats)), - (draw(ffloats), None), + (None, draw(finite_floats())), + (draw(finite_floats()), None), ] ) ) @@ -58,7 +57,7 @@ def intervals( hnp.arrays( dtype=float, shape=(2,), - elements=ffloats, + elements=finite_floats(), unique=True, ).map(sorted) )