From 9ab13258b5cdf9e772b4d0956746cf07ea03716d Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 13 Aug 2024 09:45:43 -0400 Subject: [PATCH 1/3] shared default rt across tests --- model/src/test/test_forecast.py | 20 +++-------- model/src/test/test_latent_admissions.py | 20 +++-------- model/src/test/test_latent_infections.py | 20 ++--------- model/src/test/test_model_basic_renewal.py | 32 ++--------------- model/src/test/test_model_hosp_admissions.py | 37 ++------------------ model/src/test/test_predictive.py | 18 +++------- model/src/test/test_random_key.py | 20 +++-------- model/src/test/utils.py | 33 +++++++++++++++++ 8 files changed, 58 insertions(+), 142 deletions(-) create mode 100644 model/src/test/utils.py diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 1e293b47..38bf353a 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -1,10 +1,11 @@ # numpydoc ignore=GL08 +from test.utils import get_default_rt + import jax.numpy as jnp import jax.random as jr import numpyro import numpyro.distributions as dist -import pyrenew.transformation as t from numpy.testing import assert_array_equal from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import ( @@ -12,10 +13,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import SimpleRandomWalkProcess def test_forecast(): @@ -30,19 +30,7 @@ def test_forecast(): ) latent_infections = Infections() observed_infections = PoissonObservation(name="poisson_rv") - rt = TransformedRandomVariable( - name="Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", dist=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), - ) + rt = get_default_rt() model = RtInfectionsRenewalModel( I0_rv=I0, diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 73f41c17..8fe0d846 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 +from test.utils import get_default_rt + import jax.numpy as jnp import numpy.testing as testing import numpyro import numpyro.distributions as dist -from pyrenew import transformation as t from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable -from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import DistributionalRV def test_admissions_sample(): @@ -20,19 +20,7 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the hospital admissions - rt = TransformedRandomVariable( - name="Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", dist=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), - ) + rt = get_default_rt() with numpyro.handlers.seed(rng_seed=223): sim_rt = rt(n_steps=30)[0].value diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index fcfd3f99..ea19e4f4 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -1,15 +1,13 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 +from test.utils import get_default_rt + import jax.numpy as jnp import numpy.testing as testing import numpyro -import numpyro.distributions as dist -import pyrenew.transformation as t import pytest from pyrenew.latent import Infections -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable -from pyrenew.process import SimpleRandomWalkProcess def test_infections_as_deterministic(): @@ -18,19 +16,7 @@ def test_infections_as_deterministic(): the same seed is used. """ - rt = TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", dist=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), - ) + rt = get_default_rt() with numpyro.handlers.seed(rng_seed=223): sim_rt, *_ = rt(n_steps=30) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index 34fce28b..df795192 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -2,13 +2,14 @@ # numpydoc ignore=GL08 +from test.utils import get_default_rt + import jax.numpy as jnp import jax.random as jr import numpy as np import numpyro import numpyro.distributions as dist import polars as pl -import pyrenew.transformation as t import pytest from pyrenew.deterministic import DeterministicPMF, NullObservation from pyrenew.latent import ( @@ -16,36 +17,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import SimpleRandomWalkProcess - - -def get_default_rt(): - """ - Helper function to create a default Rt - RandomVariable for this testing session. - - Returns - ------- - TransformedRandomVariable : - A log-scale random walk with fixed - init value and step size priors - """ - return TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", dist=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), - ) def test_model_basicrenewal_no_timepoints_or_observations(): diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py index 4766d9e6..84c1f7f3 100644 --- a/model/src/test/test_model_hosp_admissions.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -2,6 +2,8 @@ # numpydoc ignore=GL08 +from test.utils import get_default_rt + import jax.numpy as jnp import jax.random as jr import numpy as np @@ -9,7 +11,6 @@ import numpyro.distributions as dist import polars as pl import pytest -from pyrenew import transformation as t from pyrenew.deterministic import ( DeterministicPMF, DeterministicVariable, @@ -21,41 +22,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import ( - DistributionalRV, - RandomVariable, - SampledValue, - TransformedRandomVariable, -) +from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation -from pyrenew.process import SimpleRandomWalkProcess - - -def get_default_rt(): - """ - Helper function to create a default Rt - RandomVariable for this testing session. - - Returns - ------- - TransformedRandomVariable : - A log-scale random walk with fixed - init value and step size priors - """ - return TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", dist=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), - ) class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08 diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index 1a15a3c2..5a8ae08d 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -4,9 +4,10 @@ Ensures that posterior predictive samples are not generated when no posterior samples are available. """ +from test.utils import get_default_rt + import jax.numpy as jnp import numpyro.distributions as dist -import pyrenew.transformation as t import pytest from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import ( @@ -14,10 +15,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import SimpleRandomWalkProcess pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45]) gen_int = DeterministicPMF(name="gen_int", value=pmf_array) @@ -29,17 +29,7 @@ ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") -rt = TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)), - ), - transforms=t.ExpTransform(), -) +rt = get_default_rt() model = RtInfectionsRenewalModel( I0_rv=I0, diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index 4bcc644f..b7d0bf43 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -5,11 +5,12 @@ with different random keys behave appropriately. """ +from test.utils import get_default_rt + import jax.numpy as jnp import jax.random as jr import numpyro import numpyro.distributions as dist -import pyrenew.transformation as t from numpy.testing import assert_array_equal, assert_raises from pyrenew.deterministic import DeterministicPMF from pyrenew.latent import ( @@ -17,10 +18,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.metaclass import DistributionalRV from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation -from pyrenew.process import SimpleRandomWalkProcess def create_test_model(): # numpydoc ignore=GL08 @@ -34,19 +34,7 @@ def create_test_model(): # numpydoc ignore=GL08 ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") - rt = TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", dist=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), - ) + rt = get_default_rt() model = RtInfectionsRenewalModel( I0_rv=I0, gen_int_rv=gen_int, diff --git a/model/src/test/utils.py b/model/src/test/utils.py new file mode 100644 index 00000000..e6d66d94 --- /dev/null +++ b/model/src/test/utils.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +""" +test utilities +""" + +import numpyro.distributions as dist +import pyrenew.transformation as t + +from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.process import SimpleRandomWalkProcess + + +def get_default_rt(): + """ + Helper function to create a default Rt + RandomVariable for testing. + + Returns + ------- + TransformedRandomVariable : + A log-scale random walk with fixed + init value and step size priors + """ + return TransformedRandomVariable( + "Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV(name="rw_step_rv", dist=dist.Normal(0, 0.025)), + init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)), + ), + transforms=t.ExpTransform(), + ) From b20c03e060b9fefc02b7ecc33c0602932365875a Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 13 Aug 2024 10:03:49 -0400 Subject: [PATCH 2/3] update formatting and deptry ignore rule --- .github/workflows/deptry.yml | 2 +- model/src/test/utils.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/deptry.yml b/.github/workflows/deptry.yml index 071281df..d8efe991 100644 --- a/.github/workflows/deptry.yml +++ b/.github/workflows/deptry.yml @@ -29,4 +29,4 @@ jobs: - name: run deptry run: | - poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP001=pytest,DEP003=pytest" + poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP001=pytest,DEP003=pytest,DEP001=test" diff --git a/model/src/test/utils.py b/model/src/test/utils.py index e6d66d94..5c22f5d5 100644 --- a/model/src/test/utils.py +++ b/model/src/test/utils.py @@ -6,7 +6,6 @@ import numpyro.distributions as dist import pyrenew.transformation as t - from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable from pyrenew.process import SimpleRandomWalkProcess @@ -26,8 +25,12 @@ def get_default_rt(): "Rt_rv", base_rv=SimpleRandomWalkProcess( name="log_rt", - step_rv=DistributionalRV(name="rw_step_rv", dist=dist.Normal(0, 0.025)), - init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)), + step_rv=DistributionalRV( + name="rw_step_rv", dist=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_rt", dist=dist.Normal(0, 0.2) + ), ), transforms=t.ExpTransform(), ) From 875e28392516bc6ba06fb2de210368a4dfad68fe Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 13 Aug 2024 12:50:57 -0400 Subject: [PATCH 3/3] name change and include arg for name --- model/src/test/test_forecast.py | 4 ++-- model/src/test/test_latent_admissions.py | 4 ++-- model/src/test/test_latent_infections.py | 4 ++-- model/src/test/test_model_basic_renewal.py | 12 ++++++------ model/src/test/test_model_hosp_admissions.py | 14 +++++++------- model/src/test/test_predictive.py | 4 ++-- model/src/test/test_random_key.py | 4 ++-- model/src/test/utils.py | 10 ++++++++-- 8 files changed, 31 insertions(+), 25 deletions(-) diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 38bf353a..3feaf373 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -1,6 +1,6 @@ # numpydoc ignore=GL08 -from test.utils import get_default_rt +from test.utils import simple_rt import jax.numpy as jnp import jax.random as jr @@ -30,7 +30,7 @@ def test_forecast(): ) latent_infections = Infections() observed_infections = PoissonObservation(name="poisson_rv") - rt = get_default_rt() + rt = simple_rt() model = RtInfectionsRenewalModel( I0_rv=I0, diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 8fe0d846..7648d70f 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 -from test.utils import get_default_rt +from test.utils import simple_rt import jax.numpy as jnp import numpy.testing as testing @@ -20,7 +20,7 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the hospital admissions - rt = get_default_rt() + rt = simple_rt() with numpyro.handlers.seed(rng_seed=223): sim_rt = rt(n_steps=30)[0].value diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index ea19e4f4..7a83d1c0 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 -from test.utils import get_default_rt +from test.utils import simple_rt import jax.numpy as jnp import numpy.testing as testing @@ -16,7 +16,7 @@ def test_infections_as_deterministic(): the same seed is used. """ - rt = get_default_rt() + rt = simple_rt() with numpyro.handlers.seed(rng_seed=223): sim_rt, *_ = rt(n_steps=30) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index df795192..9b5bdf46 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -2,7 +2,7 @@ # numpydoc ignore=GL08 -from test.utils import get_default_rt +from test.utils import simple_rt import jax.numpy as jnp import jax.random as jr @@ -39,7 +39,7 @@ def test_model_basicrenewal_no_timepoints_or_observations(): observed_infections = PoissonObservation("poisson_rv") - rt = get_default_rt() + rt = simple_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -70,7 +70,7 @@ def test_model_basicrenewal_both_timepoints_and_observations(): observed_infections = PoissonObservation("possion_rv") - rt = get_default_rt() + rt = simple_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -111,7 +111,7 @@ def test_model_basicrenewal_no_obs_model(): latent_infections = Infections() - rt = get_default_rt() + rt = simple_rt() model0 = RtInfectionsRenewalModel( gen_int_rv=gen_int, @@ -184,7 +184,7 @@ def test_model_basicrenewal_with_obs_model(): observed_infections = PoissonObservation("poisson_rv") - rt = get_default_rt() + rt = simple_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, @@ -233,7 +233,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 observed_infections = PoissonObservation("poisson_rv") - rt = get_default_rt() + rt = simple_rt() model1 = RtInfectionsRenewalModel( I0_rv=I0, diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py index 84c1f7f3..7e7e129b 100644 --- a/model/src/test/test_model_hosp_admissions.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -2,7 +2,7 @@ # numpydoc ignore=GL08 -from test.utils import get_default_rt +from test.utils import simple_rt import jax.numpy as jnp import jax.random as jr @@ -60,7 +60,7 @@ def test_model_hosp_no_timepoints_or_observations(): I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) latent_infections = Infections() - Rt_process = get_default_rt() + Rt_process = simple_rt() observed_admissions = PoissonObservation("poisson_rv") @@ -125,7 +125,7 @@ def test_model_hosp_both_timepoints_and_observations(): I0 = DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)) latent_infections = Infections() - Rt_process = get_default_rt() + Rt_process = simple_rt() observed_admissions = PoissonObservation("poisson_rv") @@ -198,7 +198,7 @@ def test_model_hosp_no_obs_model(): ) latent_infections = Infections() - Rt_process = get_default_rt() + Rt_process = simple_rt() inf_hosp = DeterministicPMF( name="inf_hosp", @@ -308,7 +308,7 @@ def test_model_hosp_with_obs_model(): ) latent_infections = Infections() - Rt_process = get_default_rt() + Rt_process = simple_rt() observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -395,7 +395,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): ) latent_infections = Infections() - Rt_process = get_default_rt() + Rt_process = simple_rt() observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -494,7 +494,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) latent_infections = Infections() - Rt_process = get_default_rt() + Rt_process = simple_rt() observed_admissions = PoissonObservation("poisson_rv") diff --git a/model/src/test/test_predictive.py b/model/src/test/test_predictive.py index 5a8ae08d..c848ccda 100644 --- a/model/src/test/test_predictive.py +++ b/model/src/test/test_predictive.py @@ -4,7 +4,7 @@ Ensures that posterior predictive samples are not generated when no posterior samples are available. """ -from test.utils import get_default_rt +from test.utils import simple_rt import jax.numpy as jnp import numpyro.distributions as dist @@ -29,7 +29,7 @@ ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") -rt = get_default_rt() +rt = simple_rt() model = RtInfectionsRenewalModel( I0_rv=I0, diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index b7d0bf43..247555ca 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -5,7 +5,7 @@ with different random keys behave appropriately. """ -from test.utils import get_default_rt +from test.utils import simple_rt import jax.numpy as jnp import jax.random as jr @@ -34,7 +34,7 @@ def create_test_model(): # numpydoc ignore=GL08 ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") - rt = get_default_rt() + rt = simple_rt() model = RtInfectionsRenewalModel( I0_rv=I0, gen_int_rv=gen_int, diff --git a/model/src/test/utils.py b/model/src/test/utils.py index 5c22f5d5..4bb05a52 100644 --- a/model/src/test/utils.py +++ b/model/src/test/utils.py @@ -10,11 +10,17 @@ from pyrenew.process import SimpleRandomWalkProcess -def get_default_rt(): +def simple_rt(arg_name: str = "Rt_rv"): """ Helper function to create a default Rt RandomVariable for testing. + Parameters + ----------- + arg_name : str + Name assigned to the randonvariable. + If None, then defaults to "Rt_rv" + Returns ------- TransformedRandomVariable : @@ -22,7 +28,7 @@ def get_default_rt(): init value and step size priors """ return TransformedRandomVariable( - "Rt_rv", + arg_name, base_rv=SimpleRandomWalkProcess( name="log_rt", step_rv=DistributionalRV(