From e4a27cab597f60f142210eaca66a3f97188810f0 Mon Sep 17 00:00:00 2001 From: Nikita Semenov <117774141+nisemenov@users.noreply.github.com> Date: Sun, 23 Feb 2025 17:14:02 +0300 Subject: [PATCH] fix(#634): added a generate function to support SQL Numeric field (#636) --- .../value_generators/constrained_numbers.py | 93 +++++++++++-------- tests/constraints/test_decimal_constraints.py | 58 ++++++------ .../test_sqlalchemy_factory_common.py | 34 +++++++ 3 files changed, 119 insertions(+), 66 deletions(-) diff --git a/polyfactory/value_generators/constrained_numbers.py b/polyfactory/value_generators/constrained_numbers.py index 3cf77dee..eaa8f0ac 100644 --- a/polyfactory/value_generators/constrained_numbers.py +++ b/polyfactory/value_generators/constrained_numbers.py @@ -1,10 +1,11 @@ from __future__ import annotations -from decimal import Decimal +from decimal import ROUND_DOWN, Decimal, localcontext from sys import float_info from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast from polyfactory.exceptions import ParameterException +from polyfactory.utils.deprecation import check_for_deprecated_parameters from polyfactory.value_generators.primitives import create_random_decimal, create_random_float, create_random_integer if TYPE_CHECKING: @@ -135,6 +136,8 @@ def get_value_or_none( le: T | None = None, gt: T | None = None, ge: T | None = None, + max_digits: int | None = None, + decimal_places: int | None = None, ) -> tuple[T | None, T | None]: """Return an optional value. @@ -157,6 +160,26 @@ def get_value_or_none( maximum_value = lt - get_increment(t_type) else: maximum_value = None + + if max_digits is not None: + max_whole_digits = 10 + whole_digits = max_digits - decimal_places if decimal_places is not None else max_digits + maximum = ( + Decimal(10**whole_digits - 1) if whole_digits < max_whole_digits else Decimal(10**max_whole_digits - 1) + ) + minimum = maximum * (-1) + + if minimum_value is None or minimum_value < minimum: + minimum_value = t_type(minimum) + elif minimum_value > maximum: + msg = f"minimum value must be less than {maximum}" + raise ParameterException(msg) + + if maximum_value is None or maximum_value > maximum: + maximum_value = t_type(maximum if maximum > 0 else Decimal(1)) + elif maximum_value < minimum: + msg = f"maximum value must be greater than {minimum}" + raise ParameterException(msg) return minimum_value, maximum_value @@ -168,6 +191,8 @@ def get_constrained_number_range( gt: T | None = None, ge: T | None = None, multiple_of: T | None = None, + max_digits: int | None = None, + decimal_places: int | None = None, ) -> tuple[T | None, T | None]: """Return the minimum and maximum values given a field_meta's constraints. @@ -178,11 +203,15 @@ def get_constrained_number_range( :param gt: Greater than value. :param ge: Greater than or equal value. :param multiple_of: Multiple of value. + :param decimal_places: Number of decimal places. + :param max_digits: Maximal number of digits. :returns: a tuple of optional minimum and maximum values. """ seed = t_type(random.random() * 10) - minimum, maximum = get_value_or_none(lt=lt, le=le, gt=gt, ge=ge, t_type=t_type) + minimum, maximum = get_value_or_none( + lt=lt, le=le, gt=gt, ge=ge, t_type=t_type, max_digits=max_digits, decimal_places=decimal_places + ) if minimum is not None and maximum is not None and maximum < minimum: msg = "maximum value must be greater than minimum value" @@ -329,19 +358,14 @@ def validate_max_digits( :returns: 'None' """ + check_for_deprecated_parameters("2.19.1", parameters=(("minimum", minimum),)) + if max_digits <= 0: msg = "max_digits must be greater than 0" raise ParameterException(msg) - if minimum is not None: - min_str = str(minimum).split(".")[1] if "." in str(minimum) else str(minimum) - - if max_digits <= len(min_str): - msg = "minimum is greater than max_digits" - raise ParameterException(msg) - - if decimal_places is not None and max_digits <= decimal_places: - msg = "max_digits must be greater than decimal places" + if decimal_places is not None and max_digits < decimal_places: + msg = "max_digits must be greater or equal than decimal places" raise ParameterException(msg) @@ -357,25 +381,18 @@ def handle_decimal_length( :param max_digits: Maximal number of digits. """ - string_number = str(generated_decimal) - sign = "-" if "-" in string_number else "+" - string_number = string_number.replace("-", "") - whole_numbers, decimals = string_number.split(".") - - if (max_digits is not None and decimal_places is not None and len(whole_numbers) + decimal_places > max_digits) or ( - (max_digits is None or decimal_places is None) and max_digits is not None - ): - max_decimals = max_digits - len(whole_numbers) - elif max_digits is not None: - max_decimals = decimal_places # type: ignore[assignment] - else: - max_decimals = cast("int", decimal_places) - - if max_decimals < 0: # pyright: ignore[reportOptionalOperand] - return Decimal(sign + whole_numbers[:max_decimals]) - - decimals = decimals[:max_decimals] - return Decimal(sign + whole_numbers + "." + decimals[:decimal_places]) + with localcontext() as ctx: + ctx.rounding = ROUND_DOWN + list_decimal = str(generated_decimal).strip("-0").split(".") + decimal_parts = 2 + if len(list_decimal) == decimal_parts: + whole, decimals = list_decimal + if decimal_places is not None and len(decimals) > decimal_places: + return round(generated_decimal, decimal_places) + if max_digits is not None and len(whole) + len(decimals) > max_digits: + max_decimals = max_digits - len(whole) + return round(generated_decimal, max_decimals) + return generated_decimal def handle_constrained_decimal( @@ -402,13 +419,14 @@ def handle_constrained_decimal( :returns: A decimal. """ - minimum, maximum = get_constrained_number_range( gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of, + max_digits=max_digits, + decimal_places=decimal_places, t_type=Decimal, random=random, ) @@ -424,11 +442,8 @@ def handle_constrained_decimal( method=create_random_decimal, ) - if max_digits is not None or decimal_places is not None: - return handle_decimal_length( - generated_decimal=generated_decimal, - max_digits=max_digits, - decimal_places=decimal_places, - ) - - return generated_decimal + return handle_decimal_length( + generated_decimal=generated_decimal, + max_digits=max_digits, + decimal_places=decimal_places, + ) diff --git a/tests/constraints/test_decimal_constraints.py b/tests/constraints/test_decimal_constraints.py index b6a76a3c..64dd62bb 100644 --- a/tests/constraints/test_decimal_constraints.py +++ b/tests/constraints/test_decimal_constraints.py @@ -1,6 +1,6 @@ from decimal import Decimal from random import Random -from typing import Optional, cast +from typing import cast import pytest from hypothesis import given @@ -32,12 +32,23 @@ def test_handle_constrained_decimal_without_constraints() -> None: assert isinstance(result, Decimal) -def test_handle_constrained_decimal_length_validation() -> None: - with pytest.raises(ParameterException): +@pytest.mark.parametrize( + ("msg", "ge", "le"), + ( + ("minimum value must be less than", Decimal("100.000"), Decimal()), + ("maximum value must be greater than", Decimal(), Decimal("-100.000")), + ), +) +def test_handle_constrained_decimal_length_validation(msg: str, ge: Decimal, le: Decimal) -> None: + with pytest.raises( + ParameterException, + match=msg, + ): handle_constrained_decimal( random=Random(), max_digits=2, - ge=Decimal("100.000"), + ge=ge, + le=le, ) @@ -48,9 +59,12 @@ def test_handle_constrained_decimal_handles_max_digits(max_digits: int) -> None: random=Random(), max_digits=max_digits, ) - assert len(result.as_tuple().digits) - abs(cast("int", result.as_tuple().exponent)) <= max_digits + assert len(result.as_tuple().digits) <= max_digits else: - with pytest.raises(ParameterException): + with pytest.raises( + ParameterException, + match="max_digits must be greater than 0", + ): handle_constrained_decimal( random=Random(), max_digits=max_digits, @@ -66,15 +80,23 @@ def test_handle_constrained_decimal_handles_decimal_places(decimal_places: int) assert abs(cast("int", result.as_tuple().exponent)) <= decimal_places -@given(integers(min_value=0, max_value=100), integers(min_value=1, max_value=100)) +@given(integers(min_value=0, max_value=100), integers(min_value=0, max_value=100)) def test_handle_constrained_decimal_handles_max_digits_and_decimal_places(max_digits: int, decimal_places: int) -> None: - if max_digits > 0 and max_digits > decimal_places: + if max_digits > 0 and max_digits >= decimal_places: result = handle_constrained_decimal( random=Random(), decimal_places=decimal_places, max_digits=max_digits, ) - assert len(result.as_tuple().digits) - abs(cast("int", result.as_tuple().exponent)) <= max_digits + non_fractionals = max_digits - decimal_places + list_value = str(result).strip("-").split(".") + if decimal_places: + left_digits, right_digits = list_value + assert len(right_digits) <= decimal_places + else: + left_digits = list_value[0] + assert len(left_digits) <= non_fractionals or int(left_digits) == non_fractionals + else: with pytest.raises(ParameterException): handle_constrained_decimal( @@ -368,16 +390,6 @@ class PersonFactory(ModelFactory): def test_handle_decimal_length() -> None: decimal = Decimal("999.9999999") - # here digits should determine decimal length - max_digits = 5 - decimal_places: Optional[int] = 5 - - result = handle_decimal_length(decimal, decimal_places, max_digits) - - assert isinstance(result, Decimal) - assert len(result.as_tuple().digits) == 5 - assert abs(cast("int", result.as_tuple().exponent)) == 2 - # here decimal places should determine max length max_digits = 10 decimal_places = 5 @@ -405,14 +417,6 @@ def test_handle_decimal_length() -> None: assert len(result.as_tuple().digits) == 8 assert abs(cast("int", result.as_tuple().exponent)) == 5 - # here max_decimals is below 0 - decimal = Decimal("99.99") - max_digits = 1 - result = handle_decimal_length(decimal, decimal_places, max_digits) - assert isinstance(result, Decimal) - assert len(result.as_tuple().digits) == 1 - assert cast("int", result.as_tuple().exponent) == 0 - def test_zero_to_one_range() -> None: class FractionExample(BaseModel): diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 5c8084ce..780cc17f 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -488,3 +488,37 @@ class ModelFactory(SQLAlchemyFactory[Model]): constrained_number: Decimal = instance.constrainted_number assert isinstance(constrained_number, Decimal) assert abs(len(constrained_number.as_tuple().digits) - abs(int(constrained_number.as_tuple().exponent))) <= 2 + + +@pytest.mark.parametrize( + "numeric", + ( + Numeric(), + Numeric(precision=4), + Numeric(precision=4, scale=0), + Numeric(precision=4, scale=2), + ), +) +def test_numeric_field(numeric: Numeric) -> None: + _registry = registry() + + class Base(metaclass=DeclarativeMeta): + __abstract__ = True + __allow_unmapped__ = True + + registry = _registry + metadata = _registry.metadata + + class NumericModel(Base): + __tablename__ = "numerics" + + id: Any = Column(Integer(), primary_key=True) + numeric_field: Any = Column(numeric, nullable=False) + + class NumericModelFactory(SQLAlchemyFactory[NumericModel]): ... + + result = NumericModelFactory.get_model_fields()[1] + assert result.annotation is Decimal + if constraints := result.constraints: + assert constraints.get("max_digits") == numeric.precision + assert constraints.get("decimal_places") == numeric.scale