diff --git a/model/src/pyrenew/latent/infection_initialization_process.py b/model/src/pyrenew/latent/infection_initialization_process.py index 8e8c62e2..5ec82269 100644 --- a/model/src/pyrenew/latent/infection_initialization_process.py +++ b/model/src/pyrenew/latent/infection_initialization_process.py @@ -4,7 +4,7 @@ from pyrenew.latent.infection_initialization_method import ( InfectionInitializationMethod, ) -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable, SampledValue, _assert_type class InfectionInitializationProcess(RandomVariable): @@ -74,18 +74,12 @@ def validate( ------- None """ - if not isinstance(I_pre_init_rv, RandomVariable): - raise TypeError( - "I_pre_init_rv must be an instance of RandomVariable" - f"Got {type(I_pre_init_rv)}" - ) - if not isinstance( - infection_init_method, InfectionInitializationMethod - ): - raise TypeError( - "infection_init_method must be an instance of InfectionInitializationMethod" - f"Got {type(infection_init_method)}" - ) + _assert_type("I_pre_init_rv", I_pre_init_rv, RandomVariable) + _assert_type( + "infection_init_method", + infection_init_method, + InfectionInitializationMethod, + ) def sample(self) -> tuple: """Sample the Infection Initialization Process. diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 06df3be5..b4ac6529 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -21,6 +21,36 @@ from pyrenew.transformation import Transform +def _assert_type(arg_name: str, value, expected_type) -> None: + """ + Matches TypeError arising during validation + + Parameters + ---------- + arg_name : str + Name of the argument + value : object + The object to be validated + expected_type : type + The expected object type + + Raises + ------- + TypeError + If `value` is not an instance of `expected_type`. + + Returns + ------- + None + """ + + if not isinstance(value, expected_type): + raise TypeError( + f"{arg_name} must be an instance of {expected_type}. " + f"Got {type(value)}" + ) + + def _assert_sample_and_rtype( rp: "RandomVariable", skip_if_none: bool = True ) -> None: diff --git a/model/src/test/test_assert_type.py b/model/src/test/test_assert_type.py new file mode 100644 index 00000000..ea66ddff --- /dev/null +++ b/model/src/test/test_assert_type.py @@ -0,0 +1,37 @@ +# numpydoc ignore=GL08 + +import numpyro.distributions as dist +import pytest +from pyrenew.metaclass import DistributionalRV, RandomVariable, _assert_type + + +def test_valid_assertion_types(): + """ + Test valid assertion types in _assert_type. + """ + + values = [ + 5, + "Hello", + (1,), + DistributionalRV(name="rv", dist=dist.Beta(1, 1)), + ] + arg_names = ["input_int", "input_string", "input_tuple", "input_rv"] + input_types = [int, str, tuple, RandomVariable] + + for arg, value, input in zip(arg_names, values, input_types): + _assert_type(arg, value, input) + + +def test_invalid_assertion_types(): + """ + Test invalid assertion types in _assert_type. + """ + + values = [None] * 4 + arg_names = ["input_int", "input_string", "input_tuple", "input_rv"] + input_types = [int, str, tuple, RandomVariable] + + for arg, value, input in zip(arg_names, values, input_types): + with pytest.raises(TypeError): + _assert_type(arg, value, input)