Skip to content

Commit

Permalink
Extract and use strategy for finite floats
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic authored and julianStreibel committed May 7, 2024
1 parent 665f4f0 commit d79d7e6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
8 changes: 8 additions & 0 deletions tests/hypothesis_strategies/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Strategies for basic types."""

from functools import partial

import hypothesis.strategies as st

finite_floats = partial(st.floats, allow_infinity=False, allow_nan=False)
"""A strategy producing finite (i.e., non-nan and non-infinite) floats."""
19 changes: 10 additions & 9 deletions tests/hypothesis_strategies/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,39 @@
SmoothedBoxPrior,
)

from .basic import finite_floats
from .utils import intervals

gamma_priors = st.builds(
GammaPrior,
st.floats(min_value=0, exclude_min=True),
st.floats(min_value=0, exclude_min=True),
finite_floats(min_value=0.0, exclude_min=True),
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Gamma priors."""

half_cauchy_priors = st.builds(
HalfCauchyPrior,
st.floats(min_value=0, exclude_min=True),
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Half-Cauchy priors."""

normal_priors = st.builds(
NormalPrior,
st.floats(allow_nan=False, allow_infinity=False),
st.floats(min_value=0, exclude_min=True),
finite_floats(),
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Normal priors."""

half_normal_priors = st.builds(
HalfNormalPrior,
st.floats(min_value=0, exclude_min=True),
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Half-Normal priors."""

log_normal_priors = st.builds(
LogNormalPrior,
st.floats(allow_nan=False, allow_infinity=False),
st.floats(min_value=0, exclude_min=True),
finite_floats(),
finite_floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Log-Normal priors."""

Expand All @@ -52,7 +53,7 @@ def _smoothed_box_priors(draw: st.DrawFn):
"""A strategy that generates Smoothed-Box priors."""
interval = draw(intervals(exclude_half_bounded=True, exclude_fully_unbounded=True))
sigma = draw(
st.floats(min_value=0, exclude_min=True),
finite_floats(min_value=0.0, exclude_min=True),
)

return SmoothedBoxPrior(*interval.to_tuple(), sigma)
Expand Down
11 changes: 5 additions & 6 deletions tests/hypothesis_strategies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from baybe.utils.interval import Interval

from .basic import finite_floats


class IntervalType(Enum):
"""The possible types of an interval on the real number line."""
Expand Down Expand Up @@ -38,18 +40,15 @@ def intervals(
allowed_types = [t for t, b in type_gate.items() if b]
interval_type = draw(st.sampled_from(allowed_types))

# A strategy producing finite floats
ffloats = st.floats(allow_infinity=False, allow_nan=False)

# Draw the bounds depending on the interval type
if interval_type is IntervalType.FULLY_UNBOUNDED:
bounds = (None, None)
elif interval_type is IntervalType.HALF_BOUNDED:
bounds = draw(
st.sampled_from(
[
(None, draw(ffloats)),
(draw(ffloats), None),
(None, draw(finite_floats())),
(draw(finite_floats()), None),
]
)
)
Expand All @@ -58,7 +57,7 @@ def intervals(
hnp.arrays(
dtype=float,
shape=(2,),
elements=ffloats,
elements=finite_floats(),
unique=True,
).map(sorted)
)
Expand Down

0 comments on commit d79d7e6

Please sign in to comment.