Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create and use default rt across tests #372

Merged
merged 5 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/deptry.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ jobs:

- name: run deptry
run: |
poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP001=pytest,DEP003=pytest"
poetry run deptry . --per-rule-ignores "DEP001=pyrenew,DEP001=pytest,DEP003=pytest,DEP001=test"
20 changes: 4 additions & 16 deletions model/src/test/test_forecast.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# numpydoc ignore=GL08

from test.utils import get_default_rt

import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
from numpy.testing import assert_array_equal
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import (
InfectionInitializationProcess,
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess


def test_forecast():
Expand All @@ -30,19 +30,7 @@ def test_forecast():
)
latent_infections = Infections()
observed_infections = PoissonObservation(name="poisson_rv")
rt = TransformedRandomVariable(
name="Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
rt = get_default_rt()

model = RtInfectionsRenewalModel(
I0_rv=I0,
Expand Down
20 changes: 4 additions & 16 deletions model/src/test/test_latent_admissions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

from test.utils import get_default_rt

import jax.numpy as jnp
import numpy.testing as testing
import numpyro
import numpyro.distributions as dist
from pyrenew import transformation as t
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import HospitalAdmissions, Infections
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.process import SimpleRandomWalkProcess
from pyrenew.metaclass import DistributionalRV


def test_admissions_sample():
Expand All @@ -20,19 +20,7 @@ def test_admissions_sample():

# Generating Rt and Infections to compute the hospital admissions

rt = TransformedRandomVariable(
name="Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
rt = get_default_rt()

with numpyro.handlers.seed(rng_seed=223):
sim_rt = rt(n_steps=30)[0].value
Expand Down
20 changes: 3 additions & 17 deletions model/src/test/test_latent_infections.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

from test.utils import get_default_rt

import jax.numpy as jnp
import numpy.testing as testing
import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
import pytest
from pyrenew.latent import Infections
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.process import SimpleRandomWalkProcess


def test_infections_as_deterministic():
Expand All @@ -18,19 +16,7 @@ def test_infections_as_deterministic():
the same seed is used.
"""

rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
rt = get_default_rt()

with numpyro.handlers.seed(rng_seed=223):
sim_rt, *_ = rt(n_steps=30)
Expand Down
32 changes: 3 additions & 29 deletions model/src/test/test_model_basic_renewal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,24 @@
# numpydoc ignore=GL08


from test.utils import get_default_rt

import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import numpyro.distributions as dist
import polars as pl
import pyrenew.transformation as t
import pytest
from pyrenew.deterministic import DeterministicPMF, NullObservation
from pyrenew.latent import (
InfectionInitializationProcess,
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess


def get_default_rt():
"""
Helper function to create a default Rt
RandomVariable for this testing session.

Returns
-------
TransformedRandomVariable :
A log-scale random walk with fixed
init value and step size priors
"""
return TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)


def test_model_basicrenewal_no_timepoints_or_observations():
Expand Down
37 changes: 3 additions & 34 deletions model/src/test/test_model_hosp_admissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# numpydoc ignore=GL08


from test.utils import get_default_rt

import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import numpyro.distributions as dist
import polars as pl
import pytest
from pyrenew import transformation as t
from pyrenew.deterministic import (
DeterministicPMF,
DeterministicVariable,
Expand All @@ -21,41 +22,9 @@
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import (
DistributionalRV,
RandomVariable,
SampledValue,
TransformedRandomVariable,
)
from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue
from pyrenew.model import HospitalAdmissionsModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess


def get_default_rt():
"""
Helper function to create a default Rt
RandomVariable for this testing session.

Returns
-------
TransformedRandomVariable :
A log-scale random walk with fixed
init value and step size priors
"""
return TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)


class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08
Expand Down
18 changes: 4 additions & 14 deletions model/src/test/test_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
Ensures that posterior predictive samples are not generated when no posterior samples are available.
"""

from test.utils import get_default_rt

import jax.numpy as jnp
import numpyro.distributions as dist
import pyrenew.transformation as t
import pytest
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import (
InfectionInitializationProcess,
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess

pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45])
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
Expand All @@ -29,17 +29,7 @@
)
latent_infections = Infections()
observed_infections = PoissonObservation("poisson_rv")
rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(name="init_log_rt", dist=dist.Normal(0, 0.2)),
),
transforms=t.ExpTransform(),
)
rt = get_default_rt()

model = RtInfectionsRenewalModel(
I0_rv=I0,
Expand Down
20 changes: 4 additions & 16 deletions model/src/test/test_random_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@
with different random keys behave appropriately.
"""

from test.utils import get_default_rt

import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
import pyrenew.transformation as t
from numpy.testing import assert_array_equal, assert_raises
from pyrenew.deterministic import DeterministicPMF
from pyrenew.latent import (
InfectionInitializationProcess,
Infections,
InitializeInfectionsZeroPad,
)
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.metaclass import DistributionalRV
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.observation import PoissonObservation
from pyrenew.process import SimpleRandomWalkProcess


def create_test_model(): # numpydoc ignore=GL08
Expand All @@ -34,19 +34,7 @@ def create_test_model(): # numpydoc ignore=GL08
)
latent_infections = Infections()
observed_infections = PoissonObservation("poisson_rv")
rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
rt = get_default_rt()
model = RtInfectionsRenewalModel(
I0_rv=I0,
gen_int_rv=gen_int,
Expand Down
36 changes: 36 additions & 0 deletions model/src/test/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-

"""
test utilities
"""

import numpyro.distributions as dist
import pyrenew.transformation as t
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.process import SimpleRandomWalkProcess


def get_default_rt():
"""
Helper function to create a default Rt
RandomVariable for testing.

Returns
-------
TransformedRandomVariable :
A log-scale random walk with fixed
init value and step size priors
"""
return TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", dist=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
name="init_log_rt", dist=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)