Skip to content

Commit

Permalink
sum of iterable post hook (#251)
Browse files Browse the repository at this point in the history
* added in sum of iterable post hook
* linted
  • Loading branch information
ncilfone authored Apr 28, 2022
1 parent 733f987 commit a14fb43
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 5 deletions.
57 changes: 53 additions & 4 deletions spock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
from argparse import _ArgumentGroup
from enum import EnumMeta
from math import isclose
from pathlib import Path
from time import localtime, strftime
from typing import Any, Dict, List, Tuple, Type, TypeVar, Union
Expand Down Expand Up @@ -54,15 +55,15 @@ def _get_callable_type():
_C = TypeVar("_C", bound=type)


def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
"""Checks that all values passed in the iterable are of the same length
def _filter_optional(val: List, allow_optional: bool = True):
"""Filters an iterable for None values if they are allowed
Args:
val: iterable to compare lengths
val: iterable of values that might contain None
allow_optional: allows the check to succeed if a given val in the iterable is None
Returns:
None
filtered list of values with None values removed
Raises:
_SpockValueError
Expand All @@ -76,6 +77,54 @@ def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
)
elif v is not None:
filtered_val.append(v)
return filtered_val


def sum_vals(
val: List[Union[float, int, None]],
sum_val: Union[float, int],
allow_optional: bool = True,
rel_tol: float = 1e-9,
abs_tol: float = 0.0,
):
"""Checks if an iterable of values sums within tolerance to a specified value
Args:
val: iterable of values to sum
sum_val: sum value to compare against
allow_optional: allows the check to succeed if a given val in the iterable is None
rel_tol: relative tolerance – it is the maximum allowed difference between a and b
abs_tol: the minimum absolute tolerance – useful for comparisons near zero
Returns:
None
Raises:
_SpockValueError
"""
filtered_val = _filter_optional(val, allow_optional)
if not isclose(sum(filtered_val), sum_val, rel_tol=rel_tol, abs_tol=abs_tol):
raise _SpockValueError(
f"Sum of iterable is `{sum(filtered_val)}` which is not equal to specified value `{sum_val}` within given tolerances"
)


def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
"""Checks that all values passed in the iterable are of the same length
Args:
val: iterable to compare lengths
allow_optional: allows the check to succeed if a given val in the iterable is None
Returns:
None
Raises:
_SpockValueError
"""
filtered_val = _filter_optional(val, allow_optional)
# just do a set comprehension -- iterables shouldn't be that long so pay the O(n) price
lens = {len(v) for v in filtered_val}
if len(lens) != 1:
Expand Down
54 changes: 53 additions & 1 deletion tests/base/test_post_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from spock import spock
from spock import SpockBuilder
from spock.utils import within, gt, ge, lt, le, eq_len
from spock.utils import within, gt, ge, lt, le, eq_len, sum_vals
from spock.exceptions import _SpockInstantiationError


Expand Down Expand Up @@ -130,8 +130,60 @@ def __post_hook__(self):
eq_len([self.val_1, self.val_2, self.val_3], allow_optional=True)


@spock
class SumNoneFailConfig:
val_1: float = 0.5
val_2: float = 0.5
val_3: Optional[float] = None

def __post_hook__(self):
sum_vals([self.val_1, self.val_2, self.val_3], sum_val=1.0, allow_optional=False)


@spock
class SumNoneNotEqualConfig:
val_1: float = 0.5
val_2: float = 0.5
val_3: Optional[float] = None

def __post_hook__(self):
sum_vals([self.val_1, self.val_2, self.val_3], sum_val=0.75)


class TestPostHooks:

def test_sum_none_fail_config(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
m.setattr(
sys,
"argv",
[""],
)
with pytest.raises(_SpockInstantiationError):
config = SpockBuilder(
SumNoneFailConfig,
desc="Test Builder",
)
config.generate()

def test_sum_not_equal_config(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
m.setattr(
sys,
"argv",
[""],
)
with pytest.raises(_SpockInstantiationError):
config = SpockBuilder(
SumNoneNotEqualConfig,
desc="Test Builder",
)
config.generate()



def test_eq_len_two_len_fail(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
Expand Down

0 comments on commit a14fb43

Please sign in to comment.