Skip to content

Commit

Permalink
fix(#634): added a generate function to support SQL Numeric field (#636)
Browse files Browse the repository at this point in the history
  • Loading branch information
nisemenov authored Feb 23, 2025
1 parent 34182d4 commit e4a27ca
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 66 deletions.
93 changes: 54 additions & 39 deletions polyfactory/value_generators/constrained_numbers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from decimal import Decimal
from decimal import ROUND_DOWN, Decimal, localcontext
from sys import float_info
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast

from polyfactory.exceptions import ParameterException
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.value_generators.primitives import create_random_decimal, create_random_float, create_random_integer

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,6 +136,8 @@ def get_value_or_none(
le: T | None = None,
gt: T | None = None,
ge: T | None = None,
max_digits: int | None = None,
decimal_places: int | None = None,
) -> tuple[T | None, T | None]:
"""Return an optional value.
Expand All @@ -157,6 +160,26 @@ def get_value_or_none(
maximum_value = lt - get_increment(t_type)
else:
maximum_value = None

if max_digits is not None:
max_whole_digits = 10
whole_digits = max_digits - decimal_places if decimal_places is not None else max_digits
maximum = (
Decimal(10**whole_digits - 1) if whole_digits < max_whole_digits else Decimal(10**max_whole_digits - 1)
)
minimum = maximum * (-1)

if minimum_value is None or minimum_value < minimum:
minimum_value = t_type(minimum)
elif minimum_value > maximum:
msg = f"minimum value must be less than {maximum}"
raise ParameterException(msg)

if maximum_value is None or maximum_value > maximum:
maximum_value = t_type(maximum if maximum > 0 else Decimal(1))
elif maximum_value < minimum:
msg = f"maximum value must be greater than {minimum}"
raise ParameterException(msg)
return minimum_value, maximum_value


Expand All @@ -168,6 +191,8 @@ def get_constrained_number_range(
gt: T | None = None,
ge: T | None = None,
multiple_of: T | None = None,
max_digits: int | None = None,
decimal_places: int | None = None,
) -> tuple[T | None, T | None]:
"""Return the minimum and maximum values given a field_meta's constraints.
Expand All @@ -178,11 +203,15 @@ def get_constrained_number_range(
:param gt: Greater than value.
:param ge: Greater than or equal value.
:param multiple_of: Multiple of value.
:param decimal_places: Number of decimal places.
:param max_digits: Maximal number of digits.
:returns: a tuple of optional minimum and maximum values.
"""
seed = t_type(random.random() * 10)
minimum, maximum = get_value_or_none(lt=lt, le=le, gt=gt, ge=ge, t_type=t_type)
minimum, maximum = get_value_or_none(
lt=lt, le=le, gt=gt, ge=ge, t_type=t_type, max_digits=max_digits, decimal_places=decimal_places
)

if minimum is not None and maximum is not None and maximum < minimum:
msg = "maximum value must be greater than minimum value"
Expand Down Expand Up @@ -329,19 +358,14 @@ def validate_max_digits(
:returns: 'None'
"""
check_for_deprecated_parameters("2.19.1", parameters=(("minimum", minimum),))

if max_digits <= 0:
msg = "max_digits must be greater than 0"
raise ParameterException(msg)

if minimum is not None:
min_str = str(minimum).split(".")[1] if "." in str(minimum) else str(minimum)

if max_digits <= len(min_str):
msg = "minimum is greater than max_digits"
raise ParameterException(msg)

if decimal_places is not None and max_digits <= decimal_places:
msg = "max_digits must be greater than decimal places"
if decimal_places is not None and max_digits < decimal_places:
msg = "max_digits must be greater or equal than decimal places"
raise ParameterException(msg)


Expand All @@ -357,25 +381,18 @@ def handle_decimal_length(
:param max_digits: Maximal number of digits.
"""
string_number = str(generated_decimal)
sign = "-" if "-" in string_number else "+"
string_number = string_number.replace("-", "")
whole_numbers, decimals = string_number.split(".")

if (max_digits is not None and decimal_places is not None and len(whole_numbers) + decimal_places > max_digits) or (
(max_digits is None or decimal_places is None) and max_digits is not None
):
max_decimals = max_digits - len(whole_numbers)
elif max_digits is not None:
max_decimals = decimal_places # type: ignore[assignment]
else:
max_decimals = cast("int", decimal_places)

if max_decimals < 0: # pyright: ignore[reportOptionalOperand]
return Decimal(sign + whole_numbers[:max_decimals])

decimals = decimals[:max_decimals]
return Decimal(sign + whole_numbers + "." + decimals[:decimal_places])
with localcontext() as ctx:
ctx.rounding = ROUND_DOWN
list_decimal = str(generated_decimal).strip("-0").split(".")
decimal_parts = 2
if len(list_decimal) == decimal_parts:
whole, decimals = list_decimal
if decimal_places is not None and len(decimals) > decimal_places:
return round(generated_decimal, decimal_places)
if max_digits is not None and len(whole) + len(decimals) > max_digits:
max_decimals = max_digits - len(whole)
return round(generated_decimal, max_decimals)
return generated_decimal


def handle_constrained_decimal(
Expand All @@ -402,13 +419,14 @@ def handle_constrained_decimal(
:returns: A decimal.
"""

minimum, maximum = get_constrained_number_range(
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
max_digits=max_digits,
decimal_places=decimal_places,
t_type=Decimal,
random=random,
)
Expand All @@ -424,11 +442,8 @@ def handle_constrained_decimal(
method=create_random_decimal,
)

if max_digits is not None or decimal_places is not None:
return handle_decimal_length(
generated_decimal=generated_decimal,
max_digits=max_digits,
decimal_places=decimal_places,
)

return generated_decimal
return handle_decimal_length(
generated_decimal=generated_decimal,
max_digits=max_digits,
decimal_places=decimal_places,
)
58 changes: 31 additions & 27 deletions tests/constraints/test_decimal_constraints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from decimal import Decimal
from random import Random
from typing import Optional, cast
from typing import cast

import pytest
from hypothesis import given
Expand Down Expand Up @@ -32,12 +32,23 @@ def test_handle_constrained_decimal_without_constraints() -> None:
assert isinstance(result, Decimal)


def test_handle_constrained_decimal_length_validation() -> None:
with pytest.raises(ParameterException):
@pytest.mark.parametrize(
("msg", "ge", "le"),
(
("minimum value must be less than", Decimal("100.000"), Decimal()),
("maximum value must be greater than", Decimal(), Decimal("-100.000")),
),
)
def test_handle_constrained_decimal_length_validation(msg: str, ge: Decimal, le: Decimal) -> None:
with pytest.raises(
ParameterException,
match=msg,
):
handle_constrained_decimal(
random=Random(),
max_digits=2,
ge=Decimal("100.000"),
ge=ge,
le=le,
)


Expand All @@ -48,9 +59,12 @@ def test_handle_constrained_decimal_handles_max_digits(max_digits: int) -> None:
random=Random(),
max_digits=max_digits,
)
assert len(result.as_tuple().digits) - abs(cast("int", result.as_tuple().exponent)) <= max_digits
assert len(result.as_tuple().digits) <= max_digits
else:
with pytest.raises(ParameterException):
with pytest.raises(
ParameterException,
match="max_digits must be greater than 0",
):
handle_constrained_decimal(
random=Random(),
max_digits=max_digits,
Expand All @@ -66,15 +80,23 @@ def test_handle_constrained_decimal_handles_decimal_places(decimal_places: int)
assert abs(cast("int", result.as_tuple().exponent)) <= decimal_places


@given(integers(min_value=0, max_value=100), integers(min_value=1, max_value=100))
@given(integers(min_value=0, max_value=100), integers(min_value=0, max_value=100))
def test_handle_constrained_decimal_handles_max_digits_and_decimal_places(max_digits: int, decimal_places: int) -> None:
if max_digits > 0 and max_digits > decimal_places:
if max_digits > 0 and max_digits >= decimal_places:
result = handle_constrained_decimal(
random=Random(),
decimal_places=decimal_places,
max_digits=max_digits,
)
assert len(result.as_tuple().digits) - abs(cast("int", result.as_tuple().exponent)) <= max_digits
non_fractionals = max_digits - decimal_places
list_value = str(result).strip("-").split(".")
if decimal_places:
left_digits, right_digits = list_value
assert len(right_digits) <= decimal_places
else:
left_digits = list_value[0]
assert len(left_digits) <= non_fractionals or int(left_digits) == non_fractionals

else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
Expand Down Expand Up @@ -368,16 +390,6 @@ class PersonFactory(ModelFactory):
def test_handle_decimal_length() -> None:
decimal = Decimal("999.9999999")

# here digits should determine decimal length
max_digits = 5
decimal_places: Optional[int] = 5

result = handle_decimal_length(decimal, decimal_places, max_digits)

assert isinstance(result, Decimal)
assert len(result.as_tuple().digits) == 5
assert abs(cast("int", result.as_tuple().exponent)) == 2

# here decimal places should determine max length
max_digits = 10
decimal_places = 5
Expand Down Expand Up @@ -405,14 +417,6 @@ def test_handle_decimal_length() -> None:
assert len(result.as_tuple().digits) == 8
assert abs(cast("int", result.as_tuple().exponent)) == 5

# here max_decimals is below 0
decimal = Decimal("99.99")
max_digits = 1
result = handle_decimal_length(decimal, decimal_places, max_digits)
assert isinstance(result, Decimal)
assert len(result.as_tuple().digits) == 1
assert cast("int", result.as_tuple().exponent) == 0


def test_zero_to_one_range() -> None:
class FractionExample(BaseModel):
Expand Down
34 changes: 34 additions & 0 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,37 @@ class ModelFactory(SQLAlchemyFactory[Model]):
constrained_number: Decimal = instance.constrainted_number
assert isinstance(constrained_number, Decimal)
assert abs(len(constrained_number.as_tuple().digits) - abs(int(constrained_number.as_tuple().exponent))) <= 2


@pytest.mark.parametrize(
"numeric",
(
Numeric(),
Numeric(precision=4),
Numeric(precision=4, scale=0),
Numeric(precision=4, scale=2),
),
)
def test_numeric_field(numeric: Numeric) -> None:
_registry = registry()

class Base(metaclass=DeclarativeMeta):
__abstract__ = True
__allow_unmapped__ = True

registry = _registry
metadata = _registry.metadata

class NumericModel(Base):
__tablename__ = "numerics"

id: Any = Column(Integer(), primary_key=True)
numeric_field: Any = Column(numeric, nullable=False)

class NumericModelFactory(SQLAlchemyFactory[NumericModel]): ...

result = NumericModelFactory.get_model_fields()[1]
assert result.annotation is Decimal
if constraints := result.constraints:
assert constraints.get("max_digits") == numeric.precision
assert constraints.get("decimal_places") == numeric.scale

0 comments on commit e4a27ca

Please sign in to comment.